1 /**
2 * Copyright 2019-2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "utils/tensor_construct_utils.h"
17 #include <memory>
18 #include <vector>
19 #include <map>
20 #include <functional>
21 namespace mindspore {
CreateZerosTensor(const TypePtr & type_ptr,const std::vector<int64_t> & shape)22 tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape) {
23 MS_EXCEPTION_IF_NULL(type_ptr);
24 auto type_id = ExtractTypeId(type_ptr);
25 tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
26 size_t mem_size = IntToSize(tensor->ElementsNum());
27 auto tensor_data = tensor->data_c();
28 char *data = reinterpret_cast<char *>(tensor_data);
29 MS_EXCEPTION_IF_NULL(data);
30 (void)memset_s(data, mem_size, 0, mem_size);
31
32 return tensor;
33 }
34
CreateOnesTensor(const TypePtr & type_ptr,const std::vector<int64_t> & shape)35 tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape) {
36 MS_EXCEPTION_IF_NULL(type_ptr);
37 auto type_id = ExtractTypeId(type_ptr);
38 tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
39 const size_t &mem_size = IntToSize(tensor->ElementsNum());
40 auto tensor_data = tensor->data_c();
41 std::map<TypeId, std::function<void()>> type_dict{
42 {kNumberTypeBool, [&tensor_data, mem_size]() { SetTensorData<bool>(tensor_data, true, mem_size); }},
43 {kNumberTypeInt8,
44 [&tensor_data, mem_size]() { SetTensorData<int8_t>(tensor_data, static_cast<int8_t>(1), mem_size); }},
45 {kNumberTypeInt16,
46 [&tensor_data, mem_size]() { SetTensorData<int16_t>(tensor_data, static_cast<int16_t>(1), mem_size); }},
47 {kNumberTypeInt32,
48 [&tensor_data, mem_size]() { SetTensorData<int32_t>(tensor_data, static_cast<int32_t>(1), mem_size); }},
49 {kNumberTypeInt64,
50 [&tensor_data, mem_size]() { SetTensorData<int64_t>(tensor_data, static_cast<int64_t>(1), mem_size); }},
51 {kNumberTypeUInt8,
52 [&tensor_data, mem_size]() { SetTensorData<uint8_t>(tensor_data, static_cast<uint8_t>(1), mem_size); }},
53 {kNumberTypeUInt16,
54 [&tensor_data, mem_size]() { SetTensorData<uint16_t>(tensor_data, static_cast<uint16_t>(1), mem_size); }},
55 {kNumberTypeUInt32,
56 [&tensor_data, mem_size]() { SetTensorData<uint32_t>(tensor_data, static_cast<uint32_t>(1), mem_size); }},
57 {kNumberTypeUInt64,
58 [&tensor_data, mem_size]() { SetTensorData<uint64_t>(tensor_data, static_cast<uint64_t>(1), mem_size); }},
59 {kNumberTypeFloat16,
60 [&tensor_data, mem_size]() { SetTensorData<float16>(tensor_data, static_cast<float16>(1.0), mem_size); }},
61 {kNumberTypeFloat32,
62 [&tensor_data, mem_size]() { SetTensorData<float>(tensor_data, static_cast<float>(1.0), mem_size); }},
63 {kNumberTypeFloat64,
64 [&tensor_data, mem_size]() { SetTensorData<double>(tensor_data, static_cast<double>(1.0), mem_size); }},
65 };
66
67 const auto &tensor_type = tensor->data_type();
68 if (type_dict.count(tensor_type)) {
69 type_dict[tensor_type]();
70 return tensor;
71 } else {
72 MS_LOG(EXCEPTION) << "unsupported data type: " << tensor_type;
73 }
74 }
75
CreateTensor(const TypePtr & type_ptr,const std::vector<int64_t> & shape,void * data)76 tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape,
77 void *data) {
78 MS_EXCEPTION_IF_NULL(type_ptr);
79 auto type_id = ExtractTypeId(type_ptr);
80 tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape, data, type_id);
81 return tensor;
82 }
83
ExtractTypeId(const TypePtr & type_ptr)84 TypeId TensorConstructUtils::ExtractTypeId(const TypePtr &type_ptr) {
85 MS_EXCEPTION_IF_NULL(type_ptr);
86 TypeId type_id;
87 if (type_ptr->isa<TensorType>()) {
88 auto tensor_type = type_ptr->cast<TensorTypePtr>();
89 MS_EXCEPTION_IF_NULL(tensor_type);
90 type_id = tensor_type->element()->type_id();
91 } else {
92 type_id = type_ptr->type_id();
93 }
94 return type_id;
95 }
96 } // namespace mindspore
97