1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ir/kernel_tensor_value.h"
18
19 namespace mindspore {
20
KernelTensorValue(size_t size,const TypePtr & t)21 KernelTensorValue::KernelTensorValue(size_t size, const TypePtr &t) : Value(t) {
22 if (t) {
23 obj_type_id_ = t->object_type();
24 }
25 mutable_data_ = std::shared_ptr<uint8_t[]>(new (std::nothrow) uint8_t[size]);
26 size_ = size;
27 use_mutable_storage_ = true;
28 }
29
KernelTensorValue(const void * data,size_t size,const TypePtr & t)30 KernelTensorValue::KernelTensorValue(const void *data, size_t size, const TypePtr &t) : Value(t) {
31 if (t) {
32 obj_type_id_ = t->object_type();
33 }
34 mutable_data_ = data;
35 size_ = size;
36 use_mutable_storage_ = true;
37 }
38
KernelTensorValue(const tensor::TensorDataPtr & tensor_data,const TypePtr & t)39 KernelTensorValue::KernelTensorValue(const tensor::TensorDataPtr &tensor_data, const TypePtr &t) : Value(t) {
40 const_data_ = tensor_data;
41 obj_type_id_ = kObjectTypeTensorType;
42 }
43
KernelTensorValue(std::vector<uint8_t> && array_data,const TypePtr & t)44 KernelTensorValue::KernelTensorValue(std::vector<uint8_t> &&array_data, const TypePtr &t) : Value(t) {
45 const_data_ = std::move(array_data);
46 obj_type_id_ = kObjectTypeTuple;
47 }
48
KernelTensorValue(const StringImmPtr & string,const TypePtr & t)49 KernelTensorValue::KernelTensorValue(const StringImmPtr &string, const TypePtr &t) : Value(t) {
50 const_data_ = string;
51 obj_type_id_ = kObjectTypeString;
52 }
53
operator ==(const Value & other) const54 bool KernelTensorValue::operator==(const Value &other) const {
55 if (other.isa<KernelTensorValue>()) {
56 return *this == static_cast<const KernelTensorValue &>(other);
57 } else {
58 return false;
59 }
60 }
61
operator ==(const KernelTensorValue & other) const62 bool KernelTensorValue::operator==(const KernelTensorValue &other) const {
63 if (use_mutable_storage_) {
64 return mutable_data_ == other.mutable_data_;
65 }
66 return const_data_ == other.const_data_;
67 }
68
GetMutableDataPtr()69 void *KernelTensorValue::GetMutableDataPtr() {
70 if (use_mutable_storage_ && std::holds_alternative<std::shared_ptr<uint8_t[]>>(mutable_data_)) {
71 return std::get<std::shared_ptr<uint8_t[]>>(mutable_data_).get();
72 }
73 MS_LOG(EXCEPTION) << "Can not get mutable data pointer for read-only KernelTensorValue.";
74 }
75
GetDataPtr() const76 const void *KernelTensorValue::GetDataPtr() const {
77 if (use_mutable_storage_) {
78 if (std::holds_alternative<std::shared_ptr<uint8_t[]>>(mutable_data_)) {
79 return std::get<std::shared_ptr<uint8_t[]>>(mutable_data_).get();
80 }
81 return std::get<const void *>(mutable_data_);
82 }
83
84 switch (obj_type_id_) {
85 case kObjectTypeNumber:
86 case kObjectTypeTuple: {
87 const std::vector<uint8_t> &data = std::get<std::vector<uint8_t>>(const_data_);
88 if (data.empty()) {
89 return nullptr;
90 }
91 return data.data();
92 }
93
94 case kObjectTypeTensorType: {
95 const tensor::TensorDataPtr &tensor_data = std::get<tensor::TensorDataPtr>(const_data_);
96 MS_EXCEPTION_IF_NULL(tensor_data);
97 return tensor_data->data();
98 }
99
100 case kObjectTypeString: {
101 const StringImmPtr &string_imm = std::get<StringImmPtr>(const_data_);
102 MS_EXCEPTION_IF_NULL(string_imm);
103 return string_imm->value().data();
104 }
105
106 default:
107 MS_LOG(EXCEPTION) << "Can not get data pointer for type: " << TypeIdLabel(obj_type_id_);
108 }
109 }
110
GetDataSize() const111 size_t KernelTensorValue::GetDataSize() const {
112 if (use_mutable_storage_) {
113 return size_;
114 }
115
116 switch (obj_type_id_) {
117 case kObjectTypeNumber:
118 case kObjectTypeTuple: {
119 const std::vector<uint8_t> &data = std::get<std::vector<uint8_t>>(const_data_);
120 if (data.empty()) {
121 return 0;
122 }
123 return data.size();
124 }
125
126 case kObjectTypeTensorType: {
127 const tensor::TensorDataPtr &tensor_data = std::get<tensor::TensorDataPtr>(const_data_);
128 MS_EXCEPTION_IF_NULL(tensor_data);
129 return tensor_data->nbytes();
130 }
131
132 case kObjectTypeString: {
133 const StringImmPtr &string_imm = std::get<StringImmPtr>(const_data_);
134 MS_EXCEPTION_IF_NULL(string_imm);
135 return string_imm->value().size();
136 }
137
138 default:
139 MS_LOG(EXCEPTION) << "Can not get data size for type: " << TypeIdLabel(obj_type_id_);
140 }
141 }
142
SetDataPtr(const void * data_ptr)143 void KernelTensorValue::SetDataPtr(const void *data_ptr) {
144 MS_EXCEPTION_IF_NULL(data_ptr);
145 if (!use_mutable_storage_) {
146 MS_LOG(EXCEPTION) << "Can not set data for const KernelTensorValue.";
147 }
148 if (std::holds_alternative<const void *>(mutable_data_)) {
149 mutable_data_ = data_ptr;
150 return;
151 }
152
153 MS_LOG(EXCEPTION) << "Can not set data pointer for KernelTensorValue which uses shared pointer to storage data.";
154 }
155
Resize(size_t size)156 void KernelTensorValue::Resize(size_t size) {
157 if (!use_mutable_storage_) {
158 MS_LOG(EXCEPTION) << "Can not resize const KernelTensorValue.";
159 }
160
161 if (std::holds_alternative<std::shared_ptr<uint8_t[]>>(mutable_data_)) {
162 if (size_ < size) {
163 mutable_data_ = std::shared_ptr<uint8_t[]>(new (std::nothrow) uint8_t[size]);
164 }
165 }
166 size_ = size;
167 }
168 } // namespace mindspore
169