• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pybind_api/ir/hook_py.h"
18 #include <memory>
19 #include <string>
20 #include "include/common/utils/hook.h"
21 
22 namespace mindspore {
23 namespace tensor {
24 
25 namespace {
BuildAutoGradMeta(const tensor::Tensor & tensor)26 AutoGradMetaDataWeakPtr BuildAutoGradMeta(const tensor::Tensor &tensor) {
27   auto auto_grad_meta_data = tensor.auto_grad_meta_data();
28   if (auto_grad_meta_data == nullptr) {
29     auto_grad_meta_data = std::make_shared<AutoGradMetaData>();
30     const_cast<Tensor &>(tensor).set_auto_grad_meta_data(auto_grad_meta_data);
31     MS_LOG(DEBUG) << "Tensor has no auto_grad_meta_data, build it";
32   }
33   return {auto_grad_meta_data};
34 }
35 
// Parse the numeric part of a tensor id string of the form "T<number>":
// drop the one-character type tag, then convert the remaining digits.
inline uint64_t GetTensorNumId(const std::string &id) {
  const auto digits = id.substr(1);
  return std::stoull(digits);
}
37 }  // namespace
38 
// Bookkeeping for registered tensor hooks, keyed by numeric tensor id.
// Each entry pairs a weak reference to the tensor's autograd metadata with the
// hook object itself, so the hook can be re-attached when the metadata is
// rebuilt (see UpdateTensorBackwardHook) and removed via its id.
std::map<uint64_t, std::pair<AutoGradMetaDataWeakPtr, TensorBackwardHookPtr>> RegisterHook::hook_meta_fn_map_ = {};
40 
RegisterTensorBackwardHook(const Tensor & tensor,const py::function & hook)41 uint64_t RegisterHook::RegisterTensorBackwardHook(const Tensor &tensor, const py::function &hook) {
42   // Delete char 'T'
43   const auto &tensor_id = GetTensorNumId(tensor.id());
44   MS_LOG(DEBUG) << "Register hook " << py::str(py::cast<py::object>(hook)).cast<std::string>() << " for tensor "
45                 << tensor.ToString() << " with id " << tensor_id;
46   auto meta = BuildAutoGradMeta(tensor);
47   MS_EXCEPTION_IF_NULL(meta.lock());
48   meta.lock()->ClearBackwardHooks();
49   auto tensor_backward_hook = std::make_shared<TensorBackwardHook>(tensor_id, hook);
50   meta.lock()->AddBackwardHook(tensor_id, tensor_backward_hook);
51   // Just keep last hook
52   hook_meta_fn_map_[tensor_id] = {meta, tensor_backward_hook};
53   return tensor_id;
54 }
55 
RemoveTensorBackwardHook(uint64_t id)56 void RegisterHook::RemoveTensorBackwardHook(uint64_t id) {
57   const auto it = hook_meta_fn_map_.find(id);
58   if (it == hook_meta_fn_map_.end()) {
59     return;
60   }
61   auto meta = it->second.first.lock();
62   if (meta == nullptr) {
63     return;
64   }
65   MS_LOG(DEBUG) << "Remove hook by id " << id;
66   meta->RemoveBackwardHook(id);
67 }
68 
UpdateTensorBackwardHook(const AutoGradMetaDataPtr & auto_grad_meta_data,const std::string & id)69 void RegisterHook::UpdateTensorBackwardHook(const AutoGradMetaDataPtr &auto_grad_meta_data, const std::string &id) {
70   MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
71   const auto &tensor_id = GetTensorNumId(id);
72   auto it = hook_meta_fn_map_.find(tensor_id);
73   if (it != hook_meta_fn_map_.end()) {
74     MS_LOG(DEBUG) << "Update tensor backward hook for tensor id " << id;
75     auto_grad_meta_data->AddBackwardHook(tensor_id, it->second.second);
76     // Update remove handle
77     hook_meta_fn_map_[tensor_id].first = std::weak_ptr<AutoGradMetaData>(auto_grad_meta_data);
78   }
79 }
80 }  // namespace tensor
81 }  // namespace mindspore
82