• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "utils/tensor_construct_utils.h"
17 #include <memory>
18 #include <vector>
19 #include <map>
20 #include <functional>
21 namespace mindspore {
CreateZerosTensor(const TypePtr & type,const std::vector<int64_t> & shape)22 tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr &type, const std::vector<int64_t> &shape) {
23   MS_EXCEPTION_IF_NULL(type);
24   auto type_id = ExtractTypeId(type);
25   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
26   size_t mem_size = LongToSize(tensor->ElementsNum());
27   auto tensor_data = tensor->data_c();
28   char *data = reinterpret_cast<char *>(tensor_data);
29   MS_EXCEPTION_IF_NULL(data);
30   if (memset_s(data, mem_size, 0, mem_size) != EOK) {
31     MS_LOG(ERROR) << "Cannot create zeros tensor.";
32     return nullptr;
33   }
34 
35   return tensor;
36 }
37 
CreateOnesTensor(const TypePtr & type,const std::vector<int64_t> & shape,bool skip_exception)38 tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr &type, const std::vector<int64_t> &shape,
39                                                          bool skip_exception) {
40   MS_EXCEPTION_IF_NULL(type);
41   auto type_id = ExtractTypeId(type);
42   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
43   const size_t &mem_size = LongToSize(tensor->ElementsNum());
44   auto tensor_data = tensor->data_c();
45   std::map<TypeId, std::function<void()>> type_dict{
46     {kNumberTypeBool, [&tensor_data, mem_size]() { SetTensorData<bool>(tensor_data, true, mem_size); }},
47     {kNumberTypeInt8,
48      [&tensor_data, mem_size]() { SetTensorData<int8_t>(tensor_data, static_cast<int8_t>(1), mem_size); }},
49     {kNumberTypeInt16,
50      [&tensor_data, mem_size]() { SetTensorData<int16_t>(tensor_data, static_cast<int16_t>(1), mem_size); }},
51     {kNumberTypeInt32,
52      [&tensor_data, mem_size]() { SetTensorData<int32_t>(tensor_data, static_cast<int32_t>(1), mem_size); }},
53     {kNumberTypeInt64,
54      [&tensor_data, mem_size]() { SetTensorData<int64_t>(tensor_data, static_cast<int64_t>(1), mem_size); }},
55     {kNumberTypeUInt8,
56      [&tensor_data, mem_size]() { SetTensorData<uint8_t>(tensor_data, static_cast<uint8_t>(1), mem_size); }},
57     {kNumberTypeUInt16,
58      [&tensor_data, mem_size]() { SetTensorData<uint16_t>(tensor_data, static_cast<uint16_t>(1), mem_size); }},
59     {kNumberTypeUInt32,
60      [&tensor_data, mem_size]() { SetTensorData<uint32_t>(tensor_data, static_cast<uint32_t>(1), mem_size); }},
61     {kNumberTypeUInt64,
62      [&tensor_data, mem_size]() { SetTensorData<uint64_t>(tensor_data, static_cast<uint64_t>(1), mem_size); }},
63     {kNumberTypeFloat16,
64      [&tensor_data, mem_size]() { SetTensorData<float16>(tensor_data, static_cast<float16>(1.0), mem_size); }},
65     {kNumberTypeFloat32,
66      [&tensor_data, mem_size]() { SetTensorData<float>(tensor_data, static_cast<float>(1.0), mem_size); }},
67     {kNumberTypeFloat64,
68      [&tensor_data, mem_size]() { SetTensorData<double>(tensor_data, static_cast<double>(1.0), mem_size); }},
69     {kNumberTypeBFloat16,
70      [&tensor_data, mem_size]() { SetTensorData<bfloat16>(tensor_data, static_cast<bfloat16>(1.0), mem_size); }},
71   };
72 
73   const auto &tensor_type = tensor->data_type();
74   auto iter = type_dict.find(tensor_type);
75   if (iter == type_dict.end()) {
76     if (skip_exception) {
77       return nullptr;
78     }
79     MS_LOG(EXCEPTION) << "unsupported data type: " << tensor_type;
80   }
81   iter->second();
82   return tensor;
83 }
84 
CreateTensor(const TypePtr & type,const std::vector<int64_t> & shape,void * data)85 tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr &type, const std::vector<int64_t> &shape,
86                                                      void *data) {
87   MS_EXCEPTION_IF_NULL(type);
88   auto type_id = ExtractTypeId(type);
89   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape, data, type_id);
90   return tensor;
91 }
92 
ExtractTypeId(const TypePtr & type)93 TypeId TensorConstructUtils::ExtractTypeId(const TypePtr &type) {
94   MS_EXCEPTION_IF_NULL(type);
95   TypeId type_id;
96   if (type->isa<TensorType>()) {
97     auto tensor_type = type->cast<TensorTypePtr>();
98     MS_EXCEPTION_IF_NULL(tensor_type);
99     type_id = tensor_type->element()->type_id();
100   } else {
101     type_id = type->type_id();
102   }
103   return type_id;
104 }
105 }  // namespace mindspore
106