/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/framework/actor/output_actor.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace runtime {
namespace {
TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position) {
  MS_EXCEPTION_IF_NULL(output_node);
  MS_LOG(INFO) << "Create output tensor, output node: " << output_node->fullname_with_scope()
               << ", output index: " << output_index << ", output position: " << output_position;

  // Create the host tensor. The output tensor should use the inferred type; it will be handled correctly by the
  // tensor data sync when the inferred type is not equal to the device type.
  auto type_id = AnfAlgo::GetOutputInferDataType(output_node, output_index);
  std::vector<int64_t> temp_shape;
  auto shape = AnfAlgo::GetOutputInferShape(output_node, output_index);
  (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  auto tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(output_node, output_index));

  // Put the device tensor into the host tensor.
  const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
  tensor->set_device_address(device_tensor);

  return tensor;
}
}  // namespace

void OutputActor::Init() {
  // Set the number of actor running dependent messages. Outputs fetched from the device tensor store do not send
  // data messages, so they are excluded from the count.
  if (!need_loop_count_) {
    running_dependent_msg_num_ = SizeToInt(outputs_num_ - device_tensor_store_keys_.size());
  }
}

void OutputActor::CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);

  current_count_ = loop_count;
  if (loop_count_ == current_count_) {
    if (current_outputs_num_ + device_tensor_store_keys_.size() != outputs_num_) {
      std::string error_info = "The outputs num is wrong, the total outputs num: " + std::to_string(outputs_num_) +
                               ", the current outputs num: " + std::to_string(current_outputs_num_) +
                               ", the device tensor store num: " + std::to_string(device_tensor_store_keys_.size());
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }

    // The device tensor store can't send data, so fetch the output results of the device tensor store at the end of
    // running.
    for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
      if (device_tensor_store_key.first >= outputs_.size()) {
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The input index is out of range.");
      }
      outputs_[device_tensor_store_key.first] =
        CreateOutputTensor(device_tensor_store_key.second, 0, device_tensor_store_key.first);
    }

    current_outputs_num_ = 0;
    current_count_ = 0;
    SET_OPCONTEXT_SUCCESS_RET((*context));
  }
}

void OutputActor::UpdateOutputDeviceAddress() {
  // At the end of running, when the device tensor of a graph output node has been set into a host tensor, the graph
  // output node needs to be given a new device tensor, so that the device tensor content held by the host tensor is
  // not overwritten in the next step or next loop. The graph output nodes corresponding to the device tensor store
  // are skipped, because their addresses are fixed and persistent.
  for (size_t i = 0; i < output_nodes_.size(); ++i) {
    auto &output_node = output_nodes_[i].first;
    auto output_index = output_nodes_[i].second;
    if ((output_node != nullptr) && (!IsPersistentDeviceTensor(output_node))) {
      const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
      // Different outputs may share the same output node, so skip the node when it has already been set to a new
      // device tensor (whose device pointer is still nullptr).
      if ((device_tensor == nullptr) || (device_tensor->GetPtr() == nullptr)) {
        continue;
      }
      const auto &device_context = device_contexts_[i];
      MS_EXCEPTION_IF_NULL(device_context);
      auto new_device_tensor = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(),
                                                                   device_tensor->format(), device_tensor->type_id());
      MS_EXCEPTION_IF_NULL(new_device_tensor);
      new_device_tensor->set_original_ref_count(device_tensor->original_ref_count());
      new_device_tensor->ResetRefCount();
      AnfAlgo::SetOutputAddr(new_device_tensor, output_index, output_node.get());
    }
  }

  output_nodes_.clear();
  output_nodes_.resize(outputs_num_);
}

void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
                                OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(output_node);
  MS_EXCEPTION_IF_NULL(context);
  // Collect the output result only in the last loop, which is indicated by "loop_count_ - current_count_ == 1".
  if (loop_count_ - current_count_ != 1) {
    return;
  }

  if (output_position >= outputs_.size()) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The input index is out of range.");
  }

  auto tensor = CreateOutputTensor(output_node, output_index, output_position);
  MS_EXCEPTION_IF_NULL(tensor);
  tensor->set_need_release_device_mem(true);
  outputs_[output_position] = tensor;
  current_outputs_num_++;

  // Save the output node so that its device tensor can be cleared at the end of running.
  output_nodes_[output_position] = KernelWithIndex(output_node, output_index);

  // There is no loop count actor in step mode, so CollectLoopCount needs to be called here to replace the old output
  // device tensors.
  if (!need_loop_count_ && (current_outputs_num_ + device_tensor_store_keys_.size() == outputs_num_)) {
    CollectLoopCount(++current_count_, context);
  }
}
}  // namespace runtime
}  // namespace mindspore