/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_TENSOR_ARRAY_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_TENSOR_ARRAY_H_

#include <vector>
#include <string>
#include <memory>
#include "include/backend/kernel_graph.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "kernel/kernel.h"
#include "include/backend/mem_reuse/mem_dynamic_allocator.h"

namespace mindspore {
namespace device {
class BACKEND_EXPORT TensorArray {
 public:
  // Base TensorArray. Constructed from a name, a dtype and shapes.
  TensorArray(const string &name, const TypePtr &dtype, const ShapeVector &shapes)
      : name_(name), dtype_(dtype), shapes_(shapes), valid_size_(0), max_size_(0), is_dynamic_(true) {}
  virtual ~TensorArray() = default;

  // Check that the index is in the valid range. Used in Read().
  virtual bool CheckReadIndexLogical(const int64_t index);
  // Check the dtype and shape of the input data. Used in Write().
  virtual bool CheckValue(const TypeId &dtype, const ShapeVector &shape);

  // Write() inserts or appends dev_value at the position given by index.
  virtual bool Write(const int64_t index, const mindspore::kernel::AddressPtr &dev_value);

  // Read() returns the tensor address stored at index within tensors_.
  virtual mindspore::kernel::AddressPtr Read(const int64_t index);
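  // For illustration only (hypothetical usage, not part of this header), a typical round trip is:
  //   ta->Write(0, dev_value);                          // store a device address at index 0
  //   mindspore::kernel::AddressPtr out = ta->Read(0);  // read the same address back
  // where ta points to a concrete, device-specific TensorArray.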

  // Free() releases the device memory held by the TensorArray.
  virtual void Free();

  // These three functions must be implemented separately for each device because memory usage
  // differs between devices (see the illustrative sketch after the class).
  // AllocateMemory/FreeMemory are used to malloc/free a block of device memory; used in Write().
  // ClearMemory is used to reset the given addr with zeros; used in Free().
  virtual void FreeMemory(const DeviceMemPtr addr) = 0;
  virtual void *AllocateMemory(const size_t size) = 0;
  virtual void ClearMemory(void *addr, const size_t size) = 0;

  // Clear() only sets the valid size of the TensorArray to zero; the memory in the TensorArray is
  // kept, so it can be reused the next time the TensorArray is used.
  virtual void Clear();

  // A vector of tensor addresses is kept in a TensorArray. For memory reuse, the addresses are kept
  // after Clear(); at that point the valid size is zero while the real size is still tensors_.size().
  // In short, use GetValidSize() for the logical TensorArray size and GetRealSize() for the
  // physical one.
  virtual size_t GetValidSize() const;
  virtual size_t GetRealSize() const;
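  // Illustrative example of the two sizes (hypothetical usage, not part of the interface):
  //   ta->Write(0, a); ta->Write(1, b); ta->Write(2, c);  // GetValidSize() == 3, GetRealSize() == 3
  //   ta->Clear();                                        // GetValidSize() == 0, GetRealSize() == 3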

  // This function is only used when is_dynamic == false, in which case it sets the max size.
  // Otherwise it is not used and the default implementation applies.
  virtual void SetMaxSize(const int64_t size, const bool is_dynamic);

  // Return the tensor address at position index.
  virtual const void *GetTensorAddr(const size_t &index) const;

 protected:
  std::string name_;
  TypePtr dtype_;
  ShapeVector shapes_;
  size_t valid_size_;
  int64_t max_size_;
  bool is_dynamic_;
  // The vector tensors_ stores the device tensor addresses received by Write().
  std::vector<mindspore::kernel::AddressPtr> tensors_;
};
using TensorArrayPtr = std::shared_ptr<TensorArray>;
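
// A minimal sketch of a device-specific subclass, for illustration only. The class name and the
// plain host-memory calls (malloc/free/memset) are assumptions used as stand-ins; a real backend
// would route these through its own device allocator and keep its own bookkeeping.
//
//   class HostTensorArray : public TensorArray {
//    public:
//     using TensorArray::TensorArray;
//     void *AllocateMemory(const size_t size) override { return malloc(size); }
//     void FreeMemory(const DeviceMemPtr addr) override { free(addr); }
//     void ClearMemory(void *addr, const size_t size) override { memset(addr, 0, size); }
//   };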
}  // namespace device
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_TENSOR_ARRAY_H_