/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_IR_TENSOR_H_
#define MINDSPORE_CORE_IR_TENSOR_H_

#include <memory>
#include <string>
#include <vector>
#include <numeric>
#include <mutex>
#include <condition_variable>

#include "ir/device_sync.h"
#include "ir/meta_tensor.h"
#include "utils/log_adapter.h"
#include "base/float16.h"
#include "utils/shape_utils.h"
#include "utils/ms_exception.h"
#include "ir/device_event.h"

// brief mindspore namespace.
//
// mindspore namespace is the top level namespace of the MindSpore project.
// Other namespaces should be sub namespaces of the mindspore namespace in the ME project.
namespace mindspore {
enum TensorSyncStatus { kNoNeedSync, kNeedSyncHostToDevice, kNeedSyncDeviceToHost, kNeedSyncDeviceToHostImmediately };

// brief mindspore::tensor namespace
//
// A sub namespace in ME to support tensor related definitions.
namespace tensor {
// Tensor data interface.
class MS_CORE_API TensorData {
 public:
  /// Virtual destructor is required for base classes.
  virtual ~TensorData() = default;
  /// Total number of elements.
  virtual ssize_t size() const = 0;
  /// Byte size of a single element.
  virtual ssize_t itemsize() const = 0;
  /// Total number of bytes.
  virtual ssize_t nbytes() const = 0;
  /// Number of dimensions.
  virtual ssize_t ndim() const = 0;
  /// Data pointer.
  virtual void *data() = 0;
  /// Const data pointer.
  virtual const void *const_data() const = 0;
  /// Check whether this data equals another.
  virtual bool equals(const TensorData &other) const {
    if (this == &other) {
      return true;
    }
    // By default, compare data byte by byte.
    auto this_data = static_cast<const uint8_t *>(const_data());
    auto other_data = static_cast<const uint8_t *>(other.const_data());
    if (this_data == nullptr || other_data == nullptr) {
      // Null means the data is not initialized; comparing uninitialized data always returns false.
      return false;
    }
    return (this_data == other_data) || (ndim() == other.ndim() && nbytes() == other.nbytes() &&
                                         std::equal(this_data, this_data + nbytes(), other_data));
  }
  /// To string.
  virtual std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const = 0;
};
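
// Illustrative sketch (not part of this header): a minimal TensorData implementation for a
// contiguous float32 buffer, showing which virtual methods a concrete data backend must provide.
// The FloatTensorData name is hypothetical.
//
//   class FloatTensorData : public TensorData {
//    public:
//     FloatTensorData(std::vector<float> values, ShapeVector shape)
//         : values_(std::move(values)), shape_(std::move(shape)) {}
//     ssize_t size() const override { return static_cast<ssize_t>(values_.size()); }
//     ssize_t itemsize() const override { return sizeof(float); }
//     ssize_t nbytes() const override { return size() * itemsize(); }
//     ssize_t ndim() const override { return static_cast<ssize_t>(shape_.size()); }
//     void *data() override { return values_.data(); }
//     const void *const_data() const override { return values_.data(); }
//     std::string ToString(const TypeId, const ShapeVector &, bool) const override { return "FloatTensorData"; }
//
//    private:
//     std::vector<float> values_;
//     ShapeVector shape_;
//   };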

using TensorDataPtr = std::shared_ptr<TensorData>;

class WaitEvent : public ExceptionListener {
 public:
  void OnException() override { set_need_wait(false); }

  void Wait() const {
    std::unique_lock<std::mutex> lock(mutex_);
    if (!need_wait_) {
      return;
    }
    MsException::Instance().SetExceptionListener(const_cast<WaitEvent *>(this));
    cond_var_.wait(lock, [this] { return !need_wait_; });
    MsException::Instance().SetExceptionListener(nullptr);
    MsException::Instance().CheckException();
  }

  void set_need_wait(bool need_wait) {
    std::unique_lock<std::mutex> lock(mutex_);
    need_wait_ = need_wait;
    if (!need_wait_) {
      cond_var_.notify_all();
    }
  }

  bool need_wait() const { return need_wait_; }

 private:
  bool need_wait_{false};
  mutable std::mutex mutex_;
  mutable std::condition_variable cond_var_;
};
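
// Illustrative sketch (an assumption, not code from this header): how WaitEvent coordinates a
// waiting thread with a producer thread. One thread blocks in Wait() until another clears the
// flag via set_need_wait(false); the example assumes <thread> is available.
//
//   WaitEvent event;
//   event.set_need_wait(true);
//   std::thread producer([&event] {
//     // ... produce data ...
//     event.set_need_wait(false);  // wakes up all waiters
//   });
//   event.Wait();  // blocks until need_wait_ becomes false
//   producer.join();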

// Tensor entity class
class MS_CORE_API Tensor : public MetaTensor {
 public:
  abstract::AbstractBasePtr ToAbstract() override;

  // brief Create tensor from another tensor, data is shared.
  //
  // param tensor [Tensor] The input tensor.
  explicit Tensor(const Tensor &tensor);

  // brief Create tensor with given data type from another tensor.
  //
  // param tensor [Tensor] The input tensor.
  // param data_type [TypeId] The new tensor data type.
  Tensor(const Tensor &tensor, TypeId data_type);

  // brief Create tensor with the given shared tensor data.
  //
  // param data_type [TypeId] Data type of the tensor.
  // param shape The shape represented by ShapeVector of the tensor.
  // param data The shared tensor data.
  Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data);

  // brief Create a lazily allocated tensor.
  //
  // param data_type [TypeId] Data type of the tensor.
  // param shape The shape represented by ShapeVector of the tensor.
  Tensor(TypeId data_type, const ShapeVector &shape);

  // brief Create a tensor with an input data buffer.
  //
  // param data_type [TypeId] Data type of the tensor.
  // param shape The shape represented by ShapeVector of the tensor.
  // param data The input data to be copied into the tensor.
  // param data_len The length of data in bytes.
  Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len);

  // brief Create a tensor with an input data buffer and the given source data type.
  //
  // param data_type [TypeId] Data type of the tensor.
  // param shape The shape represented by ShapeVector of the tensor.
  // param data The input data to be copied into the tensor.
  // param src_data_type The source data type.
  Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type);

  // brief Create a one-dimensional tensor from an int vector.
  //
  // param input [std::vector<int64_t>] The data for the tensor.
  // param data_type [TypeId] Data type.
  explicit Tensor(const std::vector<int64_t> &input, const TypePtr &data_type = nullptr);

  // brief Create a one-dimensional tensor from a float vector.
  //
  // param input [std::vector<double>] The data for the tensor.
  // param data_type [TypeId] Data type.
  explicit Tensor(const std::vector<double> &input, const TypePtr &data_type = nullptr);

  // brief Create a zero-dimensional tensor from an int64_t scalar.
  //
  // param input [int64] The data for the tensor.
  // param data_type [TypeId] Data type.
  explicit Tensor(int64_t input, const TypePtr &data_type = nullptr);

  // brief Create a zero-dimensional tensor from a double scalar.
  //
  // param input [double] The data for the tensor.
  // param data_type [TypeId] Data type.
  explicit Tensor(double input, const TypePtr &data_type = nullptr);

  // brief Create a zero-dimensional tensor from a uint64_t scalar.
  //
  // param input [uint64] The data for the tensor.
  // param data_type [TypeId] Data type.
  explicit Tensor(uint64_t input, const TypePtr &data_type = nullptr);

  // brief Create a zero-dimensional tensor from a bool scalar.
  //
  // param input [bool] The data for the tensor.
  // param data_type [TypeId] Data type.
  explicit Tensor(bool input, const TypePtr &data_type = nullptr);

  ~Tensor() override = default;

  MS_DECLARE_PARENT(Tensor, MetaTensor);

  // brief Compares two Tensor objects.
  //
  // Compares two tensor objects to see if they have the same data type, shape and data address.
  //
  // param tensor The Tensor object to be compared.
  // return True if they have the same type, shape and data address; otherwise false.
  bool operator==(const Tensor &tensor) const;

  // Unlike 'operator==', which only compares shape/type/address,
  // this does a real value comparison.
  bool ValueEqual(const Tensor &tensor) const;

  // Assign the value of another tensor to this tensor.
  Tensor &AssignValue(const Tensor &tensor);

  bool operator==(const Value &other) const override {
    if (other.isa<Tensor>()) {
      auto &other_ = static_cast<const Tensor &>(other);
      return *this == other_;
    }
    return false;
  }
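
  // Illustrative sketch (assumption): the difference between operator== and ValueEqual.
  // operator== compares identity (type, shape and data address), while ValueEqual compares
  // the actual element bytes via TensorData::equals.
  //
  //   tensor::Tensor a(kNumberTypeInt64, ShapeVector{3});
  //   tensor::Tensor b(a);                    // shares data with a
  //   tensor::Tensor c(a, kNumberTypeInt64);  // may hold its own copy of the data
  //   bool same_identity = (a == b);          // true: same type, shape and data address
  //   bool same_value = a.ValueEqual(c);      // true if the contents match byte by byte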

  // brief Get the tensor's number of dimensions.
  //
  // return The number of dimensions of the tensor data.
  int DataDim() const { return static_cast<int>(data().ndim()); }

  // brief Get the tensor data size.
  //
  // return The total number of elements of the tensor data.
  int DataSize() const { return static_cast<int>(data().size()); }

  // brief Get the data type of the tensor for C++.
  //
  // return [int] The tensor's data type, cast to int.
  int data_type_c() const { return static_cast<int>(data_type_); }

  // brief Get the tensor's shape for C++.
  //
  // return [ShapeVector]
  ShapeVector shape_c() const { return shape(); }

  // brief Get the tensor data pointer for C++ types.
  //
  // return The pointer to the data object.
  void *data_c() { return data().data(); }

  // brief Get the tensor data byte size for C++ types.
  //
  // return Byte size of the tensor data.
  size_t Size() const { return static_cast<size_t>(data().nbytes()); }

  void *data_c() const { return data_->data(); }

  // brief Sync data with the device; waits until the data is valid if need_wait is true.
  void data_sync(bool need_wait = true) const;
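
  // Illustrative sketch (assumption): reading tensor data on the host. If the latest data
  // lives on the device, data_sync() must be called before dereferencing the host pointer
  // returned by data_c(). Assumes a float32 tensor and <iostream>.
  //
  //   void PrintFirstFloat(const tensor::Tensor &t) {
  //     t.data_sync();  // blocks until device data is copied back to the host
  //     auto *p = static_cast<const float *>(t.data_c());
  //     std::cout << p[0] << std::endl;
  //   }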

  // brief Get the internal data object.
  //
  // return The reference to the internal data object.
  TensorData &data() { return *data_; }

  // brief Get the internal data shared pointer.
  //
  // return The reference to the internal data shared pointer.
  const TensorDataPtr &data_ptr() const { return data_; }

  // brief Get the internal data object.
  //
  // return The const reference to the internal data object.
  const TensorData &data() const { return *data_; }

  TypeId set_data_type(const TypeId data_type) override;

  std::string GetShapeAndDataTypeInfo() const;

  std::string ToStringInternal(int limit_size) const;

  std::string ToStringNoLimit() const;

  std::string ToString() const override;

  std::string ToStringRepr() const;

  void CheckShape(const ShapeVector &shape) const;

  bool is_init() const { return init_flag_; }
  void set_init_flag(bool flag) { init_flag_ = flag; }

  DeviceSyncPtr device_address() const { return device_sync_; }
  // If need_update_ref_count is true, the device address cannot be released and reused,
  // so pass false when setting the device address of a feature-map tensor.
  void set_device_address(const DeviceSyncPtr &device_sync, bool need_update_ref_count = true) {
    device_sync_ = device_sync;
    // To support coexistence of the old and new runtime, the output of the old runtime may be
    // the input of the new runtime, so in this scenario the device address cannot be released
    // through the ref count; set the max ref count instead.
    if (need_update_ref_count && (device_sync_ != nullptr)) {
      device_sync_->set_original_ref_count(SIZE_MAX);
      device_sync_->ResetRefCount();
    }
  }

  bool need_release_device_mem() const { return need_release_device_mem_; }
  void set_need_release_device_mem(bool release_device_mem) { need_release_device_mem_ = release_device_mem; }

  void set_padding_type(const std::string &padding_type) { padding_type_ = padding_type; }
  std::string padding_type() const { return padding_type_; }

  std::string id() const { return id_; }
  TypePtr cast_dtype() { return cast_dtype_; }
  void set_cast_dtype(TypePtr dtype = nullptr) { cast_dtype_ = dtype; }

  // Used if cache_enable, in order to update the tensor from cache to host.
  bool cache_enable() const { return cache_enable_; }
  void set_cache_enable(bool cache_enable = true) { cache_enable_ = cache_enable; }
  std::shared_ptr<Tensor> hashmap_tensor_ptr() const { return hashmap_tensor_ptr_; }
  void set_hashmap_tensor_ptr(std::shared_ptr<Tensor> hashmap_tensor_ptr = nullptr) {
    hashmap_tensor_ptr_ = hashmap_tensor_ptr;
  }
  std::shared_ptr<Tensor> cache_tensor_ptr() const { return cache_tensor_ptr_; }
  void set_cache_tensor_ptr(std::shared_ptr<Tensor> cache_tensor_ptr = nullptr) {
    cache_tensor_ptr_ = cache_tensor_ptr;
  }

  void SetNeedWait(bool need_wait) {
    need_wait_ = need_wait;
    auto event = event_;
    if (event != nullptr) {
      event->set_need_wait(need_wait);
    } else if (need_wait) {
      event_ = std::make_shared<WaitEvent>();
      event_->set_need_wait(need_wait);
    }
  }

  bool NeedWait() const { return need_wait_; }

  void Wait() const {
    auto event = event_;
    if (event != nullptr) {
      event->Wait();
    }
    event_ = nullptr;
  }

  void SetDeviceEvent(const std::shared_ptr<DeviceEvent> &device_event) { device_event_ = device_event; }

  void WaitDevice() {
    if (device_event_ != nullptr) {
      device_event_->WaitEvent();
    }
  }

  bool NeedWaitDevice() const {
    if (device_event_ != nullptr) {
      return device_event_->NeedWait();
    }
    return false;
  }

  void set_sync_status(TensorSyncStatus sync_status) { sync_status_ = sync_status; }

  TensorSyncStatus sync_status() const { return sync_status_; }

  bool NeedSyncDeviceToHostImmediately() const { return sync_status_ == kNeedSyncDeviceToHostImmediately; }

  bool NeedSyncDeviceToHost() const { return sync_status_ == kNeedSyncDeviceToHost; }

  bool NeedSyncHostToDevice() const { return sync_status_ == kNeedSyncHostToDevice; }

  bool IsGraphOutput() { return graph_output_; }
  void SetIsGraphOutput() { graph_output_ = true; }
  bool IsUpdatedByDevice() { return updated_by_device_; }
  void SetIsUpdateByDevice() { updated_by_device_ = true; }

 private:
  bool init_flag_{false};
  TensorDataPtr data_{nullptr};
  std::string id_{""};
  mutable std::shared_ptr<WaitEvent> event_{nullptr};
  bool need_wait_{false};
  mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
  bool graph_output_{false};
  bool updated_by_device_{false};
  DeviceSyncPtr device_sync_{nullptr};
  // Whether to release the device address of a graph output tensor.
  bool need_release_device_mem_{false};
  bool cache_enable_{false};
  std::shared_ptr<Tensor> cache_tensor_ptr_{nullptr};
  std::shared_ptr<Tensor> hashmap_tensor_ptr_{nullptr};
  std::string padding_type_{""};
  TypePtr cast_dtype_{nullptr};
  std::shared_ptr<DeviceEvent> device_event_{nullptr};
};
using TensorPtr = std::shared_ptr<Tensor>;
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
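
// Illustrative sketch (assumption): typical construction patterns using the constructors
// declared above. kNumberTypeFloat32 is a TypeId enum value.
//
//   // Lazily allocated 2x3 float32 tensor; memory is allocated on first access.
//   TensorPtr lazy = std::make_shared<Tensor>(kNumberTypeFloat32, ShapeVector{2, 3});
//   // One-dimensional tensor built from an int vector.
//   TensorPtr vec = std::make_shared<Tensor>(std::vector<int64_t>{1, 2, 3});
//   // Zero-dimensional (scalar) tensor from a double.
//   TensorPtr scalar = std::make_shared<Tensor>(3.14);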
}  // namespace tensor
}  // namespace mindspore

#endif  // MINDSPORE_CORE_IR_TENSOR_H_