• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ir/kernel_tensor_value.h"
18 
19 namespace mindspore {
20 
KernelTensorValue(size_t size,const TypePtr & t)21 KernelTensorValue::KernelTensorValue(size_t size, const TypePtr &t) : Value(t) {
22   if (t) {
23     obj_type_id_ = t->object_type();
24   }
25   mutable_data_ = std::shared_ptr<uint8_t[]>(new (std::nothrow) uint8_t[size]);
26   size_ = size;
27   use_mutable_storage_ = true;
28 }
29 
KernelTensorValue(const void * data,size_t size,const TypePtr & t)30 KernelTensorValue::KernelTensorValue(const void *data, size_t size, const TypePtr &t) : Value(t) {
31   if (t) {
32     obj_type_id_ = t->object_type();
33   }
34   mutable_data_ = data;
35   size_ = size;
36   use_mutable_storage_ = true;
37 }
38 
KernelTensorValue(const tensor::TensorDataPtr & tensor_data,const TypePtr & t)39 KernelTensorValue::KernelTensorValue(const tensor::TensorDataPtr &tensor_data, const TypePtr &t) : Value(t) {
40   const_data_ = tensor_data;
41   obj_type_id_ = kObjectTypeTensorType;
42 }
43 
KernelTensorValue(std::vector<uint8_t> && array_data,const TypePtr & t)44 KernelTensorValue::KernelTensorValue(std::vector<uint8_t> &&array_data, const TypePtr &t) : Value(t) {
45   const_data_ = std::move(array_data);
46   obj_type_id_ = kObjectTypeTuple;
47 }
48 
KernelTensorValue(const StringImmPtr & string,const TypePtr & t)49 KernelTensorValue::KernelTensorValue(const StringImmPtr &string, const TypePtr &t) : Value(t) {
50   const_data_ = string;
51   obj_type_id_ = kObjectTypeString;
52 }
53 
operator ==(const Value & other) const54 bool KernelTensorValue::operator==(const Value &other) const {
55   if (other.isa<KernelTensorValue>()) {
56     return *this == static_cast<const KernelTensorValue &>(other);
57   } else {
58     return false;
59   }
60 }
61 
operator ==(const KernelTensorValue & other) const62 bool KernelTensorValue::operator==(const KernelTensorValue &other) const {
63   if (use_mutable_storage_) {
64     return mutable_data_ == other.mutable_data_;
65   }
66   return const_data_ == other.const_data_;
67 }
68 
GetMutableDataPtr()69 void *KernelTensorValue::GetMutableDataPtr() {
70   if (use_mutable_storage_ && std::holds_alternative<std::shared_ptr<uint8_t[]>>(mutable_data_)) {
71     return std::get<std::shared_ptr<uint8_t[]>>(mutable_data_).get();
72   }
73   MS_LOG(EXCEPTION) << "Can not get mutable data pointer for read-only KernelTensorValue.";
74 }
75 
GetDataPtr() const76 const void *KernelTensorValue::GetDataPtr() const {
77   if (use_mutable_storage_) {
78     if (std::holds_alternative<std::shared_ptr<uint8_t[]>>(mutable_data_)) {
79       return std::get<std::shared_ptr<uint8_t[]>>(mutable_data_).get();
80     }
81     return std::get<const void *>(mutable_data_);
82   }
83 
84   switch (obj_type_id_) {
85     case kObjectTypeNumber:
86     case kObjectTypeTuple: {
87       const std::vector<uint8_t> &data = std::get<std::vector<uint8_t>>(const_data_);
88       if (data.empty()) {
89         return nullptr;
90       }
91       return data.data();
92     }
93 
94     case kObjectTypeTensorType: {
95       const tensor::TensorDataPtr &tensor_data = std::get<tensor::TensorDataPtr>(const_data_);
96       MS_EXCEPTION_IF_NULL(tensor_data);
97       return tensor_data->data();
98     }
99 
100     case kObjectTypeString: {
101       const StringImmPtr &string_imm = std::get<StringImmPtr>(const_data_);
102       MS_EXCEPTION_IF_NULL(string_imm);
103       return string_imm->value().data();
104     }
105 
106     default:
107       MS_LOG(EXCEPTION) << "Can not get data pointer for type: " << TypeIdLabel(obj_type_id_);
108   }
109 }
110 
GetDataSize() const111 size_t KernelTensorValue::GetDataSize() const {
112   if (use_mutable_storage_) {
113     return size_;
114   }
115 
116   switch (obj_type_id_) {
117     case kObjectTypeNumber:
118     case kObjectTypeTuple: {
119       const std::vector<uint8_t> &data = std::get<std::vector<uint8_t>>(const_data_);
120       if (data.empty()) {
121         return 0;
122       }
123       return data.size();
124     }
125 
126     case kObjectTypeTensorType: {
127       const tensor::TensorDataPtr &tensor_data = std::get<tensor::TensorDataPtr>(const_data_);
128       MS_EXCEPTION_IF_NULL(tensor_data);
129       return tensor_data->nbytes();
130     }
131 
132     case kObjectTypeString: {
133       const StringImmPtr &string_imm = std::get<StringImmPtr>(const_data_);
134       MS_EXCEPTION_IF_NULL(string_imm);
135       return string_imm->value().size();
136     }
137 
138     default:
139       MS_LOG(EXCEPTION) << "Can not get data size for type: " << TypeIdLabel(obj_type_id_);
140   }
141 }
142 
SetDataPtr(const void * data_ptr)143 void KernelTensorValue::SetDataPtr(const void *data_ptr) {
144   MS_EXCEPTION_IF_NULL(data_ptr);
145   if (!use_mutable_storage_) {
146     MS_LOG(EXCEPTION) << "Can not set data for const KernelTensorValue.";
147   }
148   if (std::holds_alternative<const void *>(mutable_data_)) {
149     mutable_data_ = data_ptr;
150     return;
151   }
152 
153   MS_LOG(EXCEPTION) << "Can not set data pointer for KernelTensorValue which uses shared pointer to storage data.";
154 }
155 
Resize(size_t size)156 void KernelTensorValue::Resize(size_t size) {
157   if (!use_mutable_storage_) {
158     MS_LOG(EXCEPTION) << "Can not resize const KernelTensorValue.";
159   }
160 
161   if (std::holds_alternative<std::shared_ptr<uint8_t[]>>(mutable_data_)) {
162     if (size_ < size) {
163       mutable_data_ = std::shared_ptr<uint8_t[]>(new (std::nothrow) uint8_t[size]);
164     }
165   }
166   size_ = size;
167 }
168 }  // namespace mindspore
169