1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "RefMemoryManager.hpp" 9 10 #include <armnn/backends/ITensorHandleFactory.hpp> 11 12 namespace armnn 13 { 14 RefTensorHandleFactoryId()15constexpr const char * RefTensorHandleFactoryId() { return "Arm/Ref/TensorHandleFactory"; } 16 17 class RefTensorHandleFactory : public ITensorHandleFactory 18 { 19 20 public: RefTensorHandleFactory(std::shared_ptr<RefMemoryManager> mgr)21 RefTensorHandleFactory(std::shared_ptr<RefMemoryManager> mgr) 22 : m_MemoryManager(mgr), 23 m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)), 24 m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)) 25 {} 26 27 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent, 28 TensorShape const& subTensorShape, 29 unsigned int const* subTensorOrigin) const override; 30 31 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override; 32 33 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 34 DataLayout dataLayout) const override; 35 36 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 37 const bool IsMemoryManaged) const override; 38 39 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 40 DataLayout dataLayout, 41 const bool IsMemoryManaged) const override; 42 43 static const FactoryId& GetIdStatic(); 44 45 const FactoryId& GetId() const override; 46 47 bool SupportsSubTensors() const override; 48 49 MemorySourceFlags GetExportFlags() const override; 50 51 MemorySourceFlags GetImportFlags() const override; 52 53 private: 54 mutable std::shared_ptr<RefMemoryManager> m_MemoryManager; 55 MemorySourceFlags m_ImportFlags; 56 MemorySourceFlags m_ExportFlags; 57 58 }; 59 60 } // namespace armnn 61 62