1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pybind_api/ir/hook_py.h"
18 #include <memory>
19 #include <string>
20 #include "include/common/utils/hook.h"
21
22 namespace mindspore {
23 namespace tensor {
24
25 namespace {
BuildAutoGradMeta(const tensor::Tensor & tensor)26 AutoGradMetaDataWeakPtr BuildAutoGradMeta(const tensor::Tensor &tensor) {
27 auto auto_grad_meta_data = tensor.auto_grad_meta_data();
28 if (auto_grad_meta_data == nullptr) {
29 auto_grad_meta_data = std::make_shared<AutoGradMetaData>();
30 const_cast<Tensor &>(tensor).set_auto_grad_meta_data(auto_grad_meta_data);
31 MS_LOG(DEBUG) << "Tensor has no auto_grad_meta_data, build it";
32 }
33 return {auto_grad_meta_data};
34 }
35
// Converts a tensor id string such as "T123" into its numeric part by dropping the first character
// and parsing the remainder as an unsigned 64-bit integer.
// Throws std::out_of_range when `id` is empty; std::stoull throws on a non-numeric remainder.
inline uint64_t GetTensorNumId(const std::string &id) {
  const std::string digits(id, 1);
  return std::stoull(digits);
}
37 } // namespace
38
39 std::map<uint64_t, std::pair<AutoGradMetaDataWeakPtr, TensorBackwardHookPtr>> RegisterHook::hook_meta_fn_map_ = {};
40
RegisterTensorBackwardHook(const Tensor & tensor,const py::function & hook)41 uint64_t RegisterHook::RegisterTensorBackwardHook(const Tensor &tensor, const py::function &hook) {
42 // Delete char 'T'
43 const auto &tensor_id = GetTensorNumId(tensor.id());
44 MS_LOG(DEBUG) << "Register hook " << py::str(py::cast<py::object>(hook)).cast<std::string>() << " for tensor "
45 << tensor.ToString() << " with id " << tensor_id;
46 auto meta = BuildAutoGradMeta(tensor);
47 MS_EXCEPTION_IF_NULL(meta.lock());
48 meta.lock()->ClearBackwardHooks();
49 auto tensor_backward_hook = std::make_shared<TensorBackwardHook>(tensor_id, hook);
50 meta.lock()->AddBackwardHook(tensor_id, tensor_backward_hook);
51 // Just keep last hook
52 hook_meta_fn_map_[tensor_id] = {meta, tensor_backward_hook};
53 return tensor_id;
54 }
55
RemoveTensorBackwardHook(uint64_t id)56 void RegisterHook::RemoveTensorBackwardHook(uint64_t id) {
57 const auto it = hook_meta_fn_map_.find(id);
58 if (it == hook_meta_fn_map_.end()) {
59 return;
60 }
61 auto meta = it->second.first.lock();
62 if (meta == nullptr) {
63 return;
64 }
65 MS_LOG(DEBUG) << "Remove hook by id " << id;
66 meta->RemoveBackwardHook(id);
67 }
68
UpdateTensorBackwardHook(const AutoGradMetaDataPtr & auto_grad_meta_data,const std::string & id)69 void RegisterHook::UpdateTensorBackwardHook(const AutoGradMetaDataPtr &auto_grad_meta_data, const std::string &id) {
70 MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
71 const auto &tensor_id = GetTensorNumId(id);
72 auto it = hook_meta_fn_map_.find(tensor_id);
73 if (it != hook_meta_fn_map_.end()) {
74 MS_LOG(DEBUG) << "Update tensor backward hook for tensor id " << id;
75 auto_grad_meta_data->AddBackwardHook(tensor_id, it->second.second);
76 // Update remove handle
77 hook_meta_fn_map_[tensor_id].first = std::weak_ptr<AutoGradMetaData>(auto_grad_meta_data);
78 }
79 }
80 } // namespace tensor
81 } // namespace mindspore
82