/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/litert/delegate/coreml/coreml_graph.h"
#include <cstdlib>
#include <fstream>
namespace mindspore::lite {
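// Release all resources owned by this graph kernel: the wrapped kernels, the CoreML ops,
// the inserted tensors, the built model spec and the executor wrapper.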
CoreMLGraph::~CoreMLGraph() {
  for (auto *kernel : all_kernels_) {
    delete kernel;
  }
  for (auto *op : coreml_ops_) {
    delete op;
  }
  for (auto tensor : insert_tensors_) {
    MSTensor::DestroyTensorPtr(tensor);
  }
  delete ml_model_;
  delete executor_wrapper_;
}

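// Replace the graph input at the given index and propagate the new tensor to every
// wrapped kernel that consumed the original input tensor.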
void CoreMLGraph::set_input(mindspore::MSTensor in_tensor, int index) {
  MS_ASSERT(static_cast<size_t>(index) < inputs_.size());
  auto origin_tensor = this->inputs_[index];
  for (auto kernel : all_kernels_) {
    for (size_t i = 0; i < kernel->inputs().size(); i++) {
      if (kernel->inputs()[i] == origin_tensor) {
        kernel->set_input(in_tensor, i);
      }
    }
  }
  this->inputs_[index] = in_tensor;
}

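// Replace the graph output at the given index and propagate the new tensor to every
// wrapped kernel that produced the original output tensor.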
void CoreMLGraph::set_output(mindspore::MSTensor out_tensor, int index) {
  MS_ASSERT(static_cast<size_t>(index) < outputs_.size());
  auto origin_tensor = this->outputs_[index];
  for (auto kernel : all_kernels_) {
    for (size_t i = 0; i < kernel->outputs().size(); i++) {
      if (kernel->outputs()[i] == origin_tensor) {
        kernel->set_output(out_tensor, i);
      }
    }
  }
  this->outputs_[index] = out_tensor;
}

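// Build the CoreML model spec from the subgraph ops, serialize it to disk and compile it
// through the executor wrapper so that it is ready for execution.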
int CoreMLGraph::Init() {
  ml_model_ = BuildMLModel();
  if (ml_model_ == nullptr) {
    MS_LOG(ERROR) << "Build CoreML model failed.";
    return RET_ERROR;
  }
  auto model_path = SaveMLModel();
  if (model_path.empty()) {
    MS_LOG(ERROR) << "Save CoreML model failed.";
    return RET_ERROR;
  }
  executor_wrapper_ = new (std::nothrow) CoreMLExecutorWrapper();
  if (executor_wrapper_ == nullptr) {
    MS_LOG(ERROR) << "Create CoreML executor wrapper failed.";
    return RET_ERROR;
  }
  auto ret = executor_wrapper_->CompileMLModel(model_path);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Compile CoreML model failed!";
    return RET_ERROR;
  }
  return RET_OK;
}

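// Assemble the CoreML::Specification::Model: build a layer for each delegate op, attach the
// layers to the neural network spec and declare the model inputs and outputs.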
CoreML::Specification::Model *CoreMLGraph::BuildMLModel() {
  auto *model = new (std::nothrow) CoreML::Specification::Model();
  if (model == nullptr) {
    MS_LOG(ERROR) << "Create CoreML model spec failed.";
    return nullptr;
  }
  model->set_specificationversion(kCoreMLVersion4);
  auto *network = model->mutable_neuralnetwork();
  network->set_arrayinputshapemapping(CoreML::Specification::EXACT_ARRAY_MAPPING);
  for (auto &op : coreml_ops_) {
    auto ret = op->BuildLayer();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Failed to build layer for op: " << op->name();
      delete model;
      return nullptr;
    }
    op->SetMLOpInOut();
    auto layers = op->GetLayers();
    if (layers.empty()) {
      MS_LOG(ERROR) << "No layer found for op: " << op->name();
      delete model;
      return nullptr;
    }
    for (auto layer : layers) {
      MS_ASSERT(layer != nullptr);
      network->mutable_layers()->AddAllocated(layer);
    }
  }
  auto ret = SetMLModelInOut(model);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Set model input output failed.";
    delete model;
    return nullptr;
  }
  return model;
}

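// Describe the graph inputs and outputs in the model spec as multi-array features, mapping
// the MindSpore tensor data type and shape onto the CoreML feature type.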
int CoreMLGraph::SetMLModelInOut(CoreML::Specification::Model *model) {
  MS_ASSERT(model != nullptr);
  auto model_desc = model->mutable_description();
  for (const auto &in_tensor : this->inputs_) {
    // add input
    auto input = model_desc->add_input();
    input->set_name(in_tensor.Name());
    auto in_multi_array = input->mutable_type()->mutable_multiarraytype();
    if (in_tensor.DataType() == DataType::kNumberTypeFloat32) {
      in_multi_array->set_datatype(CoreML::Specification::ArrayFeatureType::FLOAT32);
    } else if (in_tensor.DataType() == DataType::kNumberTypeInt32) {
      in_multi_array->set_datatype(CoreML::Specification::ArrayFeatureType::INT32);
    } else {
      MS_LOG(ERROR) << "Unsupported model input data type: " << static_cast<int>(in_tensor.DataType());
      return RET_ERROR;
    }
    for (int64_t i : in_tensor.Shape()) {
      in_multi_array->add_shape(static_cast<uint64_t>(i));
    }
  }
  for (const auto &out_tensor : this->outputs_) {
    // add output
    auto output = model_desc->add_output();
    output->set_name(out_tensor.Name());
    auto out_multi_array = output->mutable_type()->mutable_multiarraytype();
    if (out_tensor.DataType() == DataType::kNumberTypeFloat32) {
      out_multi_array->set_datatype(CoreML::Specification::ArrayFeatureType::FLOAT32);
    } else if (out_tensor.DataType() == DataType::kNumberTypeInt32) {
      out_multi_array->set_datatype(CoreML::Specification::ArrayFeatureType::INT32);
    } else {
      MS_LOG(ERROR) << "Unsupported model output data type: " << static_cast<int>(out_tensor.DataType());
      return RET_ERROR;
    }
    for (int64_t i : out_tensor.Shape()) {
      out_multi_array->add_shape(static_cast<uint64_t>(i));
    }
  }
  return RET_OK;
}

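// Serialize the built model spec to "$HOME/tmp/<graph name>.mlmodel" and return the saved
// path; an empty string is returned when saving fails.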
std::string CoreMLGraph::SaveMLModel() {
  MS_ASSERT(ml_model_ != nullptr);
  std::string model_name = this->name() + ".mlmodel";
  auto *home_dir = getenv("HOME");
  if (home_dir == nullptr) {
    MS_LOG(ERROR) << "Get HOME environment variable failed.";
    return "";
  }
  auto model_path = std::string(home_dir) + "/tmp/" + model_name;
  std::ofstream file_stream(model_path, std::ios::out | std::ios::binary);
  if (!file_stream.is_open() || !ml_model_->SerializeToOstream(&file_stream)) {
    MS_LOG(ERROR) << "Save CoreML model to " << model_path << " failed.";
    return "";
  }
  MS_LOG(INFO) << "Save CoreML model success!";
  return model_path;
}

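// Run the compiled CoreML model through the executor wrapper with the current graph
// inputs and outputs.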
int CoreMLGraph::Execute() {
  auto ret = executor_wrapper_->Run(inputs(), outputs());
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Run CoreML model failed.";
    return ret;
  }
  MS_LOG(INFO) << "Run CoreML model success!";
  return ret;
}
}  // namespace mindspore::lite