1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "runtime/graph_scheduler/actor/control_flow/switch_actor.h"

#include <cstring>

#include "runtime/graph_scheduler/actor/control_flow/entrance_actor.h"
#include "plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
#include "abstract/utils.h"
#include "runtime/graph_scheduler/actor/output_actor.h"
#include "utils/log_adapter.h"
#include "include/common/utils/python_adapter.h"
24
25 namespace mindspore {
26 namespace runtime {
// Size in bytes of the host buffer that SwitchActor::GetIndex fills with the condition/index value.
// It is derived from the widest accepted index type (int64_t) so that the buffer can never be
// smaller than the size guard applied before the device-to-host copy.
constexpr size_t kMaxSwitchCondSize = sizeof(int64_t);
// Number of entries output_data_by_output_index_ is resized to when a switch actor is constructed.
constexpr size_t kSwitchDefaultOutputNum = 1;
29
SwitchActor(const std::string & name,const AID & memory_manager_aid,const std::vector<KernelWithIndex> & parameters,const AnfNodePtr & node)30 SwitchActor::SwitchActor(const std::string &name, const AID &memory_manager_aid,
31 const std::vector<KernelWithIndex> ¶meters, const AnfNodePtr &node)
32 : ControlActor(name, KernelTransformType::kSwitchActor, memory_manager_aid, parameters, node) {
33 device_contexts_.resize(parameters.size());
34 output_data_by_output_index_.resize(kSwitchDefaultOutputNum);
35 }
36
FetchInput(OpContext<DeviceTensor> * const context)37 void SwitchActor::FetchInput(OpContext<DeviceTensor> *const context) {
38 MS_EXCEPTION_IF_NULL(context);
39
40 // Call the base class interface to get input data and input partial.
41 ControlActor::FetchInput(context);
42
43 ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
44 size_t index = GetIndex(context);
45 if (common::IsNeedProfileMemory()) {
46 // dry run switch index is always 0.
47 index = input_partials_.size() - kSwitchCondPos - 1;
48 }
49 if (!output_partial_arrows_.empty()) {
50 if (index + kSwitchCondPos >= input_partials_.size()) {
51 MS_EXCEPTION(IndexError) << "Given index " << std::to_string(index)
52 << " out of range. Please make sure the value of index in ["
53 << std::to_string(1 - SizeToInt(input_partials_.size())) << ", "
54 << std::to_string(input_partials_.size() - 1) + "), and the type is int32.";
55 }
56 MS_EXCEPTION_IF_NULL(input_partials_[index + kSwitchCondPos]);
57 auto func_graph = input_partials_[index + kSwitchCondPos]->func_graph_;
58 MS_EXCEPTION_IF_NULL(func_graph);
59 input_partials_[0] = input_partials_[index + kSwitchCondPos];
60 }
61
62 for (auto &output_data : output_data_by_output_index_[0]) {
63 MS_EXCEPTION_IF_NULL(output_data);
64 MS_EXCEPTION_IF_NULL(input_device_tensors_[index + kSwitchCondPos]);
65 output_data->data_ = input_device_tensors_[index + kSwitchCondPos];
66 }
67 }
68
GetIndex(const OpContext<DeviceTensor> * const context) const69 size_t SwitchActor::GetIndex(const OpContext<DeviceTensor> *const context) const {
70 MS_EXCEPTION_IF_NULL(context);
71 MS_EXCEPTION_IF_NULL(input_device_tensors_[0]);
72
73 DeviceTensor *device_tensor = input_device_tensors_[0];
74 TypeId type_id = device_tensor->type_id();
75 size_t size = abstract::TypeIdSize(type_id);
76 if (size > sizeof(int64_t)) {
77 MS_LOG(ERROR) << "Index must be Int type.";
78 return 0;
79 }
80
81 int64_t index = 0;
82 char buf[kMaxSwitchCondSize] = {0};
83 ShapeVector host_shape;
84 if (device_tensor->user_data() != nullptr && device_tensor->need_sync_user_data() &&
85 device_tensor->user_data()->has(kernel::PyExecuteOutputUserData::key)) {
86 const auto &user_data_obj =
87 device_tensor->user_data()->get<kernel::PyExecuteOutputUserData>(kernel::PyExecuteOutputUserData::key);
88 MS_EXCEPTION_IF_NULL(user_data_obj);
89 const auto &obj = user_data_obj->obj;
90 py::gil_scoped_acquire gil_acquire;
91 if (py::isinstance<py::bool_>(obj)) {
92 MS_LOG(DEBUG) << "Index:" << py::cast<bool>(obj) << " for actor:" << GetAID();
93 return index = static_cast<int64_t>(py::cast<bool>(obj) ? 1 : 0);
94 }
95 }
96 if (!device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf))) {
97 MS_LOG(ERROR) << GetAID().Name() << " get index from device address failed, type id:" << type_id;
98 return 0;
99 }
100
101 if (type_id == TypeId::kNumberTypeInt32) {
102 index = static_cast<int64_t>((static_cast<int32_t *>(static_cast<void *>(buf)))[0]);
103 MS_LOG(DEBUG) << "Index:" << index << " for actor:" << GetAID();
104 } else if (type_id == TypeId::kNumberTypeInt64) {
105 index = (static_cast<int64_t *>(static_cast<void *>(buf)))[0];
106 MS_LOG(DEBUG) << "Index:" << index << " for actor:" << GetAID();
107 } else if (type_id == TypeId::kNumberTypeBool) {
108 bool cond = (static_cast<bool *>(static_cast<void *>(buf)))[0];
109 if (cond) {
110 index = 1;
111 }
112 MS_LOG(DEBUG) << "Condition:" << cond << ", index:" << index << " for actor:" << GetAID();
113 } else {
114 MS_LOG(ERROR) << "Index must be Int type.";
115 return 0;
116 }
117
118 // SwitchLayer node support negative index range [-size, -1].
119 if (index < 0) {
120 int64_t positive_index = index + SizeToLong(formal_parameters_.size() - 1);
121 if (positive_index < 0) {
122 MS_EXCEPTION(IndexError) << "Given index " << std::to_string(index)
123 << " out of range. Please make sure the value of index in ["
124 << std::to_string(1 - SizeToInt(input_partials_.size())) << ", "
125 << std::to_string(input_partials_.size() - 1) + "), and the type is int32.";
126 }
127 index = positive_index;
128 }
129 return LongToSize(index);
130 }
131 } // namespace runtime
132 } // namespace mindspore
133