1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "include/api/cell.h"

#include <memory>
#include <utility>
#include <vector>

#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "cxx_api/graph/graph_impl.h"
20
21 namespace mindspore {
operator ()(const std::vector<Input> & inputs) const22 std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const { return Clone()->Construct(inputs); }
23
// Copy constructor: stores the tensor obtained from MSTensor::Clone() of `cell`'s tensor
// instead of sharing `cell.tensor_` directly.
ParameterCell::ParameterCell(const ParameterCell &cell) {
  auto tmp_ptr = cell.tensor_.Clone();
  // Clone() hands back a raw, heap-allocated pointer; it must not be dereferenced if null.
  MS_EXCEPTION_IF_NULL(tmp_ptr);
  tensor_ = *tmp_ptr;
  // The temporary returned by Clone() is owned by us and released through the MSTensor API.
  MSTensor::DestroyTensorPtr(tmp_ptr);
}
29
// Copy assignment: replaces the held tensor with the tensor produced by MSTensor::Clone()
// of `cell`'s tensor. Self-assignment is a no-op.
ParameterCell &ParameterCell::operator=(const ParameterCell &cell) {
  if (&cell == this) {
    return *this;
  }
  auto tmp_ptr = cell.tensor_.Clone();
  // Clone() hands back a raw, heap-allocated pointer; it must not be dereferenced if null.
  MS_EXCEPTION_IF_NULL(tmp_ptr);
  tensor_ = *tmp_ptr;
  // The temporary returned by Clone() is owned by us and released through the MSTensor API.
  MSTensor::DestroyTensorPtr(tmp_ptr);
  return *this;
}
39
ParameterCell(ParameterCell && cell)40 ParameterCell::ParameterCell(ParameterCell &&cell) : tensor_(cell.tensor_) {}
41
operator =(ParameterCell && cell)42 ParameterCell &ParameterCell::operator=(ParameterCell &&cell) {
43 if (&cell == this) {
44 return *this;
45 }
46 tensor_ = cell.tensor_;
47 return *this;
48 }
49
ParameterCell(const MSTensor & tensor)50 ParameterCell::ParameterCell(const MSTensor &tensor) {
51 auto tmp_ptr = tensor.Clone();
52 tensor_ = *tmp_ptr;
53 MSTensor::DestroyTensorPtr(tmp_ptr);
54 }
55
operator =(const MSTensor & tensor)56 ParameterCell &ParameterCell::operator=(const MSTensor &tensor) {
57 auto tmp_ptr = tensor.Clone();
58 tensor_ = *tmp_ptr;
59 MSTensor::DestroyTensorPtr(tmp_ptr);
60 return *this;
61 }
62
ParameterCell(MSTensor && tensor)63 ParameterCell::ParameterCell(MSTensor &&tensor) : tensor_(tensor) {}
64
operator =(MSTensor && tensor)65 ParameterCell &ParameterCell::operator=(MSTensor &&tensor) {
66 tensor_ = tensor;
67 return *this;
68 }
69
GraphCell(const Graph & graph)70 GraphCell::GraphCell(const Graph &graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }
71
GraphCell(const std::shared_ptr<Graph> & graph)72 GraphCell::GraphCell(const std::shared_ptr<Graph> &graph) : graph_(graph) { MS_EXCEPTION_IF_NULL(graph_); }
73
GraphCell(Graph && graph)74 GraphCell::GraphCell(Graph &&graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }
75
SetContext(const std::shared_ptr<Context> & context)76 void GraphCell::SetContext(const std::shared_ptr<Context> &context) {
77 if (executor_ == nullptr) {
78 executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
79 if (executor_ == nullptr) {
80 MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
81 return;
82 }
83 executor_->SetGraph(graph_);
84 }
85 executor_->SetContext(context);
86 }
87
Run(const std::vector<MSTensor> & inputs,std::vector<MSTensor> * outputs)88 Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
89 if (executor_ == nullptr) {
90 executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
91 if (executor_ == nullptr) {
92 MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
93 return kMEFailed;
94 }
95 executor_->SetGraph(graph_);
96 }
97 return executor_->Run(inputs, outputs);
98 }
99
Load(uint32_t device_id)100 Status GraphCell::Load(uint32_t device_id) {
101 if (executor_ == nullptr) {
102 executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
103 if (executor_ == nullptr) {
104 MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
105 return kMEFailed;
106 }
107 executor_->SetGraph(graph_);
108 }
109 return executor_->Load(device_id);
110 }
111
GetInputs()112 std::vector<MSTensor> GraphCell::GetInputs() {
113 if (executor_ == nullptr) {
114 executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
115 if (executor_ == nullptr) {
116 MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
117 return {};
118 }
119 executor_->SetGraph(graph_);
120 }
121 return executor_->GetInputs();
122 }
123
GetOutputs()124 std::vector<MSTensor> GraphCell::GetOutputs() {
125 if (executor_ == nullptr) {
126 executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
127 if (executor_ == nullptr) {
128 MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
129 return {};
130 }
131 executor_->SetGraph(graph_);
132 }
133 return executor_->GetOutputs();
134 }
135
InputAndOutput()136 InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {}
137
InputAndOutput(const MSTensor & tensor)138 InputAndOutput::InputAndOutput(const MSTensor &tensor) : prev_(), index_(-1) {
139 auto tmp_ptr = tensor.Clone();
140 cell_ = std::make_shared<ParameterCell>(*tmp_ptr);
141 MSTensor::DestroyTensorPtr(tmp_ptr);
142 }
InputAndOutput(MSTensor && tensor)143 InputAndOutput::InputAndOutput(MSTensor &&tensor)
144 : cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}
145
InputAndOutput(const std::shared_ptr<CellBase> & cell,const std::vector<InputAndOutput> & prev,int32_t index)146 InputAndOutput::InputAndOutput(const std::shared_ptr<CellBase> &cell, const std::vector<InputAndOutput> &prev,
147 int32_t index)
148 : cell_(cell), prev_(prev), index_(index) {}
149 } // namespace mindspore
150