/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "utils/tensor_construct_utils.h"
#include <memory>
#include <vector>
#include <map>
#include <functional>
namespace mindspore {
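// Creates a tensor with the given type and shape and zero-fills its buffer.
// Returns nullptr if zero-filling the buffer fails.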
tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr &type, const std::vector<int64_t> &shape) {
  MS_EXCEPTION_IF_NULL(type);
  auto type_id = ExtractTypeId(type);
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
  // Use the byte size of the buffer (not the element count) so that tensors with
  // multi-byte element types are fully zeroed.
  size_t mem_size = LongToSize(tensor->data().nbytes());
  auto tensor_data = tensor->data_c();
  char *data = reinterpret_cast<char *>(tensor_data);
  MS_EXCEPTION_IF_NULL(data);
  if (memset_s(data, mem_size, 0, mem_size) != EOK) {
    MS_LOG(ERROR) << "Cannot create zeros tensor.";
    return nullptr;
  }

  return tensor;
}

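// Creates a tensor with the given type and shape and sets every element to one.
// For an unsupported element type, returns nullptr when skip_exception is true and
// raises an exception otherwise.
// Usage sketch (assuming the predefined TypePtr kFloat32 is in scope):
//   auto ones = TensorConstructUtils::CreateOnesTensor(kFloat32, {2, 3}, false);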
tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr &type, const std::vector<int64_t> &shape,
                                                         bool skip_exception) {
  MS_EXCEPTION_IF_NULL(type);
  auto type_id = ExtractTypeId(type);
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
  size_t mem_size = LongToSize(tensor->ElementsNum());
  auto tensor_data = tensor->data_c();
  std::map<TypeId, std::function<void()>> type_dict{
    {kNumberTypeBool, [&tensor_data, mem_size]() { SetTensorData<bool>(tensor_data, true, mem_size); }},
    {kNumberTypeInt8,
     [&tensor_data, mem_size]() { SetTensorData<int8_t>(tensor_data, static_cast<int8_t>(1), mem_size); }},
    {kNumberTypeInt16,
     [&tensor_data, mem_size]() { SetTensorData<int16_t>(tensor_data, static_cast<int16_t>(1), mem_size); }},
    {kNumberTypeInt32,
     [&tensor_data, mem_size]() { SetTensorData<int32_t>(tensor_data, static_cast<int32_t>(1), mem_size); }},
    {kNumberTypeInt64,
     [&tensor_data, mem_size]() { SetTensorData<int64_t>(tensor_data, static_cast<int64_t>(1), mem_size); }},
    {kNumberTypeUInt8,
     [&tensor_data, mem_size]() { SetTensorData<uint8_t>(tensor_data, static_cast<uint8_t>(1), mem_size); }},
    {kNumberTypeUInt16,
     [&tensor_data, mem_size]() { SetTensorData<uint16_t>(tensor_data, static_cast<uint16_t>(1), mem_size); }},
    {kNumberTypeUInt32,
     [&tensor_data, mem_size]() { SetTensorData<uint32_t>(tensor_data, static_cast<uint32_t>(1), mem_size); }},
    {kNumberTypeUInt64,
     [&tensor_data, mem_size]() { SetTensorData<uint64_t>(tensor_data, static_cast<uint64_t>(1), mem_size); }},
    {kNumberTypeFloat16,
     [&tensor_data, mem_size]() { SetTensorData<float16>(tensor_data, static_cast<float16>(1.0), mem_size); }},
    {kNumberTypeFloat32,
     [&tensor_data, mem_size]() { SetTensorData<float>(tensor_data, static_cast<float>(1.0), mem_size); }},
    {kNumberTypeFloat64,
     [&tensor_data, mem_size]() { SetTensorData<double>(tensor_data, static_cast<double>(1.0), mem_size); }},
    {kNumberTypeBFloat16,
     [&tensor_data, mem_size]() { SetTensorData<bfloat16>(tensor_data, static_cast<bfloat16>(1.0), mem_size); }},
  };

  const auto &tensor_type = tensor->data_type();
  auto iter = type_dict.find(tensor_type);
  if (iter == type_dict.end()) {
    if (skip_exception) {
      return nullptr;
    }
    MS_LOG(EXCEPTION) << "Unsupported data type: " << tensor_type;
  }
  iter->second();
  return tensor;
}

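// Creates a tensor with the given type and shape whose values come from the
// caller-provided host buffer `data`.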
tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr &type, const std::vector<int64_t> &shape,
                                                     void *data) {
  MS_EXCEPTION_IF_NULL(type);
  auto type_id = ExtractTypeId(type);
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape, data, type_id);
  return tensor;
}

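// Resolves the TypeId used for construction: for a TensorType this is the
// element type id, otherwise the type's own id.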
TypeId TensorConstructUtils::ExtractTypeId(const TypePtr &type) {
  MS_EXCEPTION_IF_NULL(type);
  TypeId type_id;
  if (type->isa<TensorType>()) {
    auto tensor_type = type->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_type);
    type_id = tensor_type->element()->type_id();
  } else {
    type_id = type->type_id();
  }
  return type_id;
}
}  // namespace mindspore