• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 #pragma once
6 
#include <aclCommon/BaseMemoryManager.hpp>
#include <armnn/MemorySources.hpp>
#include <armnn/backends/IMemoryManager.hpp>
#include <armnn/backends/ITensorHandleFactory.hpp>

#include <memory>
#include <utility>
11 
12 namespace armnn
13 {
14 
/// Returns the fixed string identifier "Arm/Cl/TensorHandleFactory" for the Cl tensor handle factory.
///
/// The function is constexpr, so the identifier is usable in constant expressions.
/// It only returns a string literal (static storage duration), so it cannot throw and
/// the returned pointer never dangles.
constexpr const char* ClTensorHandleFactoryId() noexcept
{
    return "Arm/Cl/TensorHandleFactory";
}
16 
17 class ClTensorHandleFactory : public ITensorHandleFactory {
18 public:
19     static const FactoryId m_Id;
20 
ClTensorHandleFactory(std::shared_ptr<ClMemoryManager> mgr)21     ClTensorHandleFactory(std::shared_ptr<ClMemoryManager> mgr)
22                           : m_MemoryManager(mgr),
23                             m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
24                             m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined))
25         {}
26 
27     std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
28                                                          const TensorShape& subTensorShape,
29                                                          const unsigned int* subTensorOrigin) const override;
30 
31     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
32 
33     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
34                                                       DataLayout dataLayout) const override;
35 
36     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
37                                                       const bool IsMemoryManaged) const override;
38 
39     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
40                                                       DataLayout dataLayout,
41                                                       const bool IsMemoryManaged) const override;
42 
43     static const FactoryId& GetIdStatic();
44 
45     const FactoryId& GetId() const override;
46 
47     bool SupportsSubTensors() const override;
48 
49     MemorySourceFlags GetExportFlags() const override;
50 
51     MemorySourceFlags GetImportFlags() const override;
52 
53 private:
54     mutable std::shared_ptr<ClMemoryManager> m_MemoryManager;
55     MemorySourceFlags m_ImportFlags;
56     MemorySourceFlags m_ExportFlags;
57 };
58 
59 } // namespace armnn