1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/cxx_api/tensor_utils.h"
18 #include "src/common/log_adapter.h"
19 #include "src/tensor.h"
20
21 namespace mindspore {
TruncateShape(const std::vector<int64_t> & shape,enum TypeId type,size_t data_len,bool verify_size)22 std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
23 bool verify_size) {
24 std::vector<int32_t> empty;
25 if (shape.empty()) {
26 return empty;
27 }
28 std::vector<int32_t> truncated_shape;
29 truncated_shape.resize(shape.size());
30 size_t element_size = lite::DataTypeSize(type);
31 for (size_t i = 0; i < shape.size(); i++) {
32 auto dim = shape[i];
33 if (dim < 0 || dim > INT_MAX || (dim != 0 && element_size > INT_MAX / static_cast<size_t>(dim))) {
34 MS_LOG(ERROR) << "Invalid shape.";
35 return empty;
36 } else {
37 element_size *= static_cast<size_t>(dim);
38 truncated_shape[i] = static_cast<int32_t>(dim);
39 }
40 }
41 if (verify_size) {
42 if (element_size != data_len) {
43 MS_LOG(ERROR) << "Invalid data size.";
44 return empty;
45 }
46 }
47 return truncated_shape;
48 }
LiteTensorToMSTensor(tensor::MSTensor * srcTensor,MSTensor * dstTensor,bool fromSession)49 Status LiteTensorToMSTensor(tensor::MSTensor *srcTensor, MSTensor *dstTensor, bool fromSession) {
50 auto impl = std::make_shared<MSTensor::Impl>(srcTensor);
51 if (impl == nullptr || impl->lite_tensor() == nullptr) {
52 MS_LOG(ERROR) << "Create tensor failed.";
53 return kLiteError;
54 }
55 impl->set_from_session(fromSession);
56 auto tensor = MSTensor(impl);
57 if (tensor == nullptr) {
58 MS_LOG(ERROR) << "Create tensor failed.";
59 return kLiteError;
60 }
61 *dstTensor = tensor;
62 return kSuccess;
63 }
64
LiteTensorsToMSTensors(const std::vector<mindspore::tensor::MSTensor * > & srcTensors,bool fromSession)65 std::vector<MSTensor> LiteTensorsToMSTensors(const std::vector<mindspore::tensor::MSTensor *> &srcTensors,
66 bool fromSession) {
67 std::vector<MSTensor> dstTensors;
68 dstTensors.reserve(srcTensors.size());
69 for (auto inTensor : srcTensors) {
70 MSTensor tensor;
71 auto status = LiteTensorToMSTensor(inTensor, &tensor, fromSession);
72 if (status != kSuccess) {
73 return {};
74 }
75 dstTensors.emplace_back(tensor);
76 }
77 return dstTensors;
78 }
79 } // namespace mindspore
80