• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "utils/tensor_construct_utils.h"
17 #include <memory>
18 #include <vector>
19 #include <map>
20 #include <functional>
21 namespace mindspore {
CreateZerosTensor(const TypePtr & type_ptr,const std::vector<int64_t> & shape)22 tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape) {
23   MS_EXCEPTION_IF_NULL(type_ptr);
24   auto type_id = ExtractTypeId(type_ptr);
25   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
26   size_t mem_size = IntToSize(tensor->ElementsNum());
27   auto tensor_data = tensor->data_c();
28   char *data = reinterpret_cast<char *>(tensor_data);
29   MS_EXCEPTION_IF_NULL(data);
30   (void)memset_s(data, mem_size, 0, mem_size);
31 
32   return tensor;
33 }
34 
CreateOnesTensor(const TypePtr & type_ptr,const std::vector<int64_t> & shape)35 tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape) {
36   MS_EXCEPTION_IF_NULL(type_ptr);
37   auto type_id = ExtractTypeId(type_ptr);
38   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
39   const size_t &mem_size = IntToSize(tensor->ElementsNum());
40   auto tensor_data = tensor->data_c();
41   std::map<TypeId, std::function<void()>> type_dict{
42     {kNumberTypeBool, [&tensor_data, mem_size]() { SetTensorData<bool>(tensor_data, true, mem_size); }},
43     {kNumberTypeInt8,
44      [&tensor_data, mem_size]() { SetTensorData<int8_t>(tensor_data, static_cast<int8_t>(1), mem_size); }},
45     {kNumberTypeInt16,
46      [&tensor_data, mem_size]() { SetTensorData<int16_t>(tensor_data, static_cast<int16_t>(1), mem_size); }},
47     {kNumberTypeInt32,
48      [&tensor_data, mem_size]() { SetTensorData<int32_t>(tensor_data, static_cast<int32_t>(1), mem_size); }},
49     {kNumberTypeInt64,
50      [&tensor_data, mem_size]() { SetTensorData<int64_t>(tensor_data, static_cast<int64_t>(1), mem_size); }},
51     {kNumberTypeUInt8,
52      [&tensor_data, mem_size]() { SetTensorData<uint8_t>(tensor_data, static_cast<uint8_t>(1), mem_size); }},
53     {kNumberTypeUInt16,
54      [&tensor_data, mem_size]() { SetTensorData<uint16_t>(tensor_data, static_cast<uint16_t>(1), mem_size); }},
55     {kNumberTypeUInt32,
56      [&tensor_data, mem_size]() { SetTensorData<uint32_t>(tensor_data, static_cast<uint32_t>(1), mem_size); }},
57     {kNumberTypeUInt64,
58      [&tensor_data, mem_size]() { SetTensorData<uint64_t>(tensor_data, static_cast<uint64_t>(1), mem_size); }},
59     {kNumberTypeFloat16,
60      [&tensor_data, mem_size]() { SetTensorData<float16>(tensor_data, static_cast<float16>(1.0), mem_size); }},
61     {kNumberTypeFloat32,
62      [&tensor_data, mem_size]() { SetTensorData<float>(tensor_data, static_cast<float>(1.0), mem_size); }},
63     {kNumberTypeFloat64,
64      [&tensor_data, mem_size]() { SetTensorData<double>(tensor_data, static_cast<double>(1.0), mem_size); }},
65   };
66 
67   const auto &tensor_type = tensor->data_type();
68   if (type_dict.count(tensor_type)) {
69     type_dict[tensor_type]();
70     return tensor;
71   } else {
72     MS_LOG(EXCEPTION) << "unsupported data type: " << tensor_type;
73   }
74 }
75 
CreateTensor(const TypePtr & type_ptr,const std::vector<int64_t> & shape,void * data)76 tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape,
77                                                      void *data) {
78   MS_EXCEPTION_IF_NULL(type_ptr);
79   auto type_id = ExtractTypeId(type_ptr);
80   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape, data, type_id);
81   return tensor;
82 }
83 
ExtractTypeId(const TypePtr & type_ptr)84 TypeId TensorConstructUtils::ExtractTypeId(const TypePtr &type_ptr) {
85   MS_EXCEPTION_IF_NULL(type_ptr);
86   TypeId type_id;
87   if (type_ptr->isa<TensorType>()) {
88     auto tensor_type = type_ptr->cast<TensorTypePtr>();
89     MS_EXCEPTION_IF_NULL(tensor_type);
90     type_id = tensor_type->element()->type_id();
91   } else {
92     type_id = type_ptr->type_id();
93   }
94   return type_id;
95 }
96 }  // namespace mindspore
97