• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "include/api/cell.h"

#include <utility>

#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "cxx_api/graph/graph_impl.h"
20 
21 namespace mindspore {
operator ()(const std::vector<Input> & inputs) const22 std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const { return Clone()->Construct(inputs); }
23 
// Copy constructor: stores a copy of `cell`'s tensor obtained through MSTensor::Clone().
// Clone() hands back a raw pointer that this function owns and releases with
// MSTensor::DestroyTensorPtr once its contents have been assigned.
ParameterCell::ParameterCell(const ParameterCell &cell) {
  auto tmp_ptr = cell.tensor_.Clone();
  if (tmp_ptr == nullptr) {
    // Treat a null pointer as a failed clone instead of dereferencing it;
    // tensor_ is left default-constructed.
    MS_LOG(ERROR) << "Clone tensor failed.";
    return;
  }
  tensor_ = *tmp_ptr;
  MSTensor::DestroyTensorPtr(tmp_ptr);
}
29 
// Copy assignment: replaces tensor_ with a copy of `cell`'s tensor obtained through
// MSTensor::Clone(). Self-assignment is a no-op.
ParameterCell &ParameterCell::operator=(const ParameterCell &cell) {
  if (&cell == this) {
    return *this;
  }
  auto tmp_ptr = cell.tensor_.Clone();
  if (tmp_ptr == nullptr) {
    // Treat a null pointer as a failed clone; keep the current tensor_ untouched.
    MS_LOG(ERROR) << "Clone tensor failed.";
    return *this;
  }
  tensor_ = *tmp_ptr;
  // The clone is owned here and must be released through the MSTensor API.
  MSTensor::DestroyTensorPtr(tmp_ptr);
  return *this;
}
39 
ParameterCell(ParameterCell && cell)40 ParameterCell::ParameterCell(ParameterCell &&cell) : tensor_(cell.tensor_) {}
41 
operator =(ParameterCell && cell)42 ParameterCell &ParameterCell::operator=(ParameterCell &&cell) {
43   if (&cell == this) {
44     return *this;
45   }
46   tensor_ = cell.tensor_;
47   return *this;
48 }
49 
ParameterCell(const MSTensor & tensor)50 ParameterCell::ParameterCell(const MSTensor &tensor) {
51   auto tmp_ptr = tensor.Clone();
52   tensor_ = *tmp_ptr;
53   MSTensor::DestroyTensorPtr(tmp_ptr);
54 }
55 
operator =(const MSTensor & tensor)56 ParameterCell &ParameterCell::operator=(const MSTensor &tensor) {
57   auto tmp_ptr = tensor.Clone();
58   tensor_ = *tmp_ptr;
59   MSTensor::DestroyTensorPtr(tmp_ptr);
60   return *this;
61 }
62 
ParameterCell(MSTensor && tensor)63 ParameterCell::ParameterCell(MSTensor &&tensor) : tensor_(tensor) {}
64 
operator =(MSTensor && tensor)65 ParameterCell &ParameterCell::operator=(MSTensor &&tensor) {
66   tensor_ = tensor;
67   return *this;
68 }
69 
GraphCell(const Graph & graph)70 GraphCell::GraphCell(const Graph &graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }
71 
GraphCell(const std::shared_ptr<Graph> & graph)72 GraphCell::GraphCell(const std::shared_ptr<Graph> &graph) : graph_(graph) { MS_EXCEPTION_IF_NULL(graph_); }
73 
GraphCell(Graph && graph)74 GraphCell::GraphCell(Graph &&graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }
75 
SetContext(const std::shared_ptr<Context> & context)76 void GraphCell::SetContext(const std::shared_ptr<Context> &context) {
77   if (executor_ == nullptr) {
78     executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
79     if (executor_ == nullptr) {
80       MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
81       return;
82     }
83     executor_->SetGraph(graph_);
84   }
85   executor_->SetContext(context);
86 }
87 
Run(const std::vector<MSTensor> & inputs,std::vector<MSTensor> * outputs)88 Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
89   if (executor_ == nullptr) {
90     executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
91     if (executor_ == nullptr) {
92       MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
93       return kMEFailed;
94     }
95     executor_->SetGraph(graph_);
96   }
97   return executor_->Run(inputs, outputs);
98 }
99 
Load(uint32_t device_id)100 Status GraphCell::Load(uint32_t device_id) {
101   if (executor_ == nullptr) {
102     executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
103     if (executor_ == nullptr) {
104       MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
105       return kMEFailed;
106     }
107     executor_->SetGraph(graph_);
108   }
109   return executor_->Load(device_id);
110 }
111 
GetInputs()112 std::vector<MSTensor> GraphCell::GetInputs() {
113   if (executor_ == nullptr) {
114     executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
115     if (executor_ == nullptr) {
116       MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
117       return {};
118     }
119     executor_->SetGraph(graph_);
120   }
121   return executor_->GetInputs();
122 }
123 
GetOutputs()124 std::vector<MSTensor> GraphCell::GetOutputs() {
125   if (executor_ == nullptr) {
126     executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
127     if (executor_ == nullptr) {
128       MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
129       return {};
130     }
131     executor_->SetGraph(graph_);
132   }
133   return executor_->GetOutputs();
134 }
135 
InputAndOutput()136 InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {}
137 
InputAndOutput(const MSTensor & tensor)138 InputAndOutput::InputAndOutput(const MSTensor &tensor) : prev_(), index_(-1) {
139   auto tmp_ptr = tensor.Clone();
140   cell_ = std::make_shared<ParameterCell>(*tmp_ptr);
141   MSTensor::DestroyTensorPtr(tmp_ptr);
142 }
InputAndOutput(MSTensor && tensor)143 InputAndOutput::InputAndOutput(MSTensor &&tensor)
144     : cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}
145 
InputAndOutput(const std::shared_ptr<CellBase> & cell,const std::vector<InputAndOutput> & prev,int32_t index)146 InputAndOutput::InputAndOutput(const std::shared_ptr<CellBase> &cell, const std::vector<InputAndOutput> &prev,
147                                int32_t index)
148     : cell_(cell), prev_(prev), index_(index) {}
149 }  // namespace mindspore
150