• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "runtime/graph_scheduler/actor/control_flow/switch_actor.h"

#include <cstring>

#include "runtime/graph_scheduler/actor/control_flow/entrance_actor.h"
#include "plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
#include "abstract/utils.h"
#include "runtime/graph_scheduler/actor/output_actor.h"
#include "utils/log_adapter.h"
#include "include/common/utils/python_adapter.h"
24 
25 namespace mindspore {
26 namespace runtime {
// Size in bytes of the host buffer GetIndex uses to read the switch condition (matches sizeof(int64_t)).
constexpr size_t kMaxSwitchCondSize{8};
// Number of output slots a switch actor allocates in its constructor.
constexpr size_t kSwitchDefaultOutputNum{1};
29 
SwitchActor(const std::string & name,const AID & memory_manager_aid,const std::vector<KernelWithIndex> & parameters,const AnfNodePtr & node)30 SwitchActor::SwitchActor(const std::string &name, const AID &memory_manager_aid,
31                          const std::vector<KernelWithIndex> &parameters, const AnfNodePtr &node)
32     : ControlActor(name, KernelTransformType::kSwitchActor, memory_manager_aid, parameters, node) {
33   device_contexts_.resize(parameters.size());
34   output_data_by_output_index_.resize(kSwitchDefaultOutputNum);
35 }
36 
FetchInput(OpContext<DeviceTensor> * const context)37 void SwitchActor::FetchInput(OpContext<DeviceTensor> *const context) {
38   MS_EXCEPTION_IF_NULL(context);
39 
40   // Call the base class interface to get input data and input partial.
41   ControlActor::FetchInput(context);
42 
43   ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
44   size_t index = GetIndex(context);
45   if (common::IsNeedProfileMemory()) {
46     // dry run switch index is always 0.
47     index = input_partials_.size() - kSwitchCondPos - 1;
48   }
49   if (!output_partial_arrows_.empty()) {
50     if (index + kSwitchCondPos >= input_partials_.size()) {
51       MS_EXCEPTION(IndexError) << "Given index " << std::to_string(index)
52                                << " out of range. Please make sure the value of index in ["
53                                << std::to_string(1 - SizeToInt(input_partials_.size())) << ", "
54                                << std::to_string(input_partials_.size() - 1) + "), and the type is int32.";
55     }
56     MS_EXCEPTION_IF_NULL(input_partials_[index + kSwitchCondPos]);
57     auto func_graph = input_partials_[index + kSwitchCondPos]->func_graph_;
58     MS_EXCEPTION_IF_NULL(func_graph);
59     input_partials_[0] = input_partials_[index + kSwitchCondPos];
60   }
61 
62   for (auto &output_data : output_data_by_output_index_[0]) {
63     MS_EXCEPTION_IF_NULL(output_data);
64     MS_EXCEPTION_IF_NULL(input_device_tensors_[index + kSwitchCondPos]);
65     output_data->data_ = input_device_tensors_[index + kSwitchCondPos];
66   }
67 }
68 
GetIndex(const OpContext<DeviceTensor> * const context) const69 size_t SwitchActor::GetIndex(const OpContext<DeviceTensor> *const context) const {
70   MS_EXCEPTION_IF_NULL(context);
71   MS_EXCEPTION_IF_NULL(input_device_tensors_[0]);
72 
73   DeviceTensor *device_tensor = input_device_tensors_[0];
74   TypeId type_id = device_tensor->type_id();
75   size_t size = abstract::TypeIdSize(type_id);
76   if (size > sizeof(int64_t)) {
77     MS_LOG(ERROR) << "Index must be Int type.";
78     return 0;
79   }
80 
81   int64_t index = 0;
82   char buf[kMaxSwitchCondSize] = {0};
83   ShapeVector host_shape;
84   if (device_tensor->user_data() != nullptr && device_tensor->need_sync_user_data() &&
85       device_tensor->user_data()->has(kernel::PyExecuteOutputUserData::key)) {
86     const auto &user_data_obj =
87       device_tensor->user_data()->get<kernel::PyExecuteOutputUserData>(kernel::PyExecuteOutputUserData::key);
88     MS_EXCEPTION_IF_NULL(user_data_obj);
89     const auto &obj = user_data_obj->obj;
90     py::gil_scoped_acquire gil_acquire;
91     if (py::isinstance<py::bool_>(obj)) {
92       MS_LOG(DEBUG) << "Index:" << py::cast<bool>(obj) << " for actor:" << GetAID();
93       return index = static_cast<int64_t>(py::cast<bool>(obj) ? 1 : 0);
94     }
95   }
96   if (!device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf))) {
97     MS_LOG(ERROR) << GetAID().Name() << " get index from device address failed, type id:" << type_id;
98     return 0;
99   }
100 
101   if (type_id == TypeId::kNumberTypeInt32) {
102     index = static_cast<int64_t>((static_cast<int32_t *>(static_cast<void *>(buf)))[0]);
103     MS_LOG(DEBUG) << "Index:" << index << " for actor:" << GetAID();
104   } else if (type_id == TypeId::kNumberTypeInt64) {
105     index = (static_cast<int64_t *>(static_cast<void *>(buf)))[0];
106     MS_LOG(DEBUG) << "Index:" << index << " for actor:" << GetAID();
107   } else if (type_id == TypeId::kNumberTypeBool) {
108     bool cond = (static_cast<bool *>(static_cast<void *>(buf)))[0];
109     if (cond) {
110       index = 1;
111     }
112     MS_LOG(DEBUG) << "Condition:" << cond << ", index:" << index << " for actor:" << GetAID();
113   } else {
114     MS_LOG(ERROR) << "Index must be Int type.";
115     return 0;
116   }
117 
118   // SwitchLayer node support negative index range [-size, -1].
119   if (index < 0) {
120     int64_t positive_index = index + SizeToLong(formal_parameters_.size() - 1);
121     if (positive_index < 0) {
122       MS_EXCEPTION(IndexError) << "Given index " << std::to_string(index)
123                                << " out of range. Please make sure the value of index in ["
124                                << std::to_string(1 - SizeToInt(input_partials_.size())) << ", "
125                                << std::to_string(input_partials_.size() - 1) + "), and the type is int32.";
126     }
127     index = positive_index;
128   }
129   return LongToSize(index);
130 }
131 }  // namespace runtime
132 }  // namespace mindspore
133