//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #pragma once
7 
8 #include "MockBackendId.hpp"
9 #include "armnn/backends/profiling/IBackendProfiling.hpp"
10 #include "armnn/backends/profiling/IBackendProfilingContext.hpp"
11 
12 #include <LayerSupportCommon.hpp>
13 #include <armnn/backends/IBackendInternal.hpp>
14 #include <armnn/backends/OptimizationViews.hpp>
15 #include <armnn/backends/profiling/IBackendProfiling.hpp>
16 #include <backends/BackendProfiling.hpp>
17 #include <backendsCommon/LayerSupportBase.hpp>
18 
19 namespace armnn
20 {
21 
/// Initialiser object for the mock backend.
///
/// Only the constructor and destructor are declared here; both are defined
/// outside this header.
/// NOTE(review): presumably construction registers the mock backend and
/// destruction removes it again - confirm against the implementation file.
class MockBackendInitialiser
{
public:
    /// Defined out of line.
    MockBackendInitialiser();

    /// Defined out of line.
    ~MockBackendInitialiser();
};
28 
29 class MockBackendProfilingContext : public profiling::IBackendProfilingContext
30 {
31 public:
MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr & backendProfiling)32     MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr& backendProfiling)
33         : m_BackendProfiling(std::move(backendProfiling))
34         , m_CapturePeriod(0)
35         , m_IsTimelineEnabled(true)
36     {}
37 
38     ~MockBackendProfilingContext() = default;
39 
GetBackendProfiling()40     IBackendInternal::IBackendProfilingPtr& GetBackendProfiling()
41     {
42         return m_BackendProfiling;
43     }
44 
RegisterCounters(uint16_t currentMaxGlobalCounterId)45     uint16_t RegisterCounters(uint16_t currentMaxGlobalCounterId)
46     {
47         std::unique_ptr<profiling::IRegisterBackendCounters> counterRegistrar =
48             m_BackendProfiling->GetCounterRegistrationInterface(static_cast<uint16_t>(currentMaxGlobalCounterId));
49 
50         std::string categoryName("MockCounters");
51         counterRegistrar->RegisterCategory(categoryName);
52 
53         counterRegistrar->RegisterCounter(0, categoryName, 0, 0, 1.f, "Mock Counter One", "Some notional counter");
54 
55         counterRegistrar->RegisterCounter(1, categoryName, 0, 0, 1.f, "Mock Counter Two",
56                                                                    "Another notional counter");
57 
58         std::string units("microseconds");
59         uint16_t nextMaxGlobalCounterId =
60                 counterRegistrar->RegisterCounter(2, categoryName, 0, 0, 1.f, "Mock MultiCore Counter",
61                                                                    "A dummy four core counter", units, 4);
62         return nextMaxGlobalCounterId;
63     }
64 
ActivateCounters(uint32_t capturePeriod,const std::vector<uint16_t> & counterIds)65     Optional<std::string> ActivateCounters(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
66     {
67         if (capturePeriod == 0 || counterIds.size() == 0)
68         {
69             m_ActiveCounters.clear();
70         }
71         else if (capturePeriod == 15939u)
72         {
73             return armnn::Optional<std::string>("ActivateCounters example test error");
74         }
75         m_CapturePeriod  = capturePeriod;
76         m_ActiveCounters = counterIds;
77         return armnn::Optional<std::string>();
78     }
79 
ReportCounterValues()80     std::vector<profiling::Timestamp> ReportCounterValues()
81     {
82         std::vector<profiling::CounterValue> counterValues;
83 
84         for (auto counterId : m_ActiveCounters)
85         {
86             counterValues.emplace_back(profiling::CounterValue{ counterId, counterId + 1u });
87         }
88 
89         uint64_t timestamp = m_CapturePeriod;
90         return { profiling::Timestamp{ timestamp, counterValues } };
91     }
92 
EnableProfiling(bool)93     bool EnableProfiling(bool)
94     {
95         auto sendTimelinePacket = m_BackendProfiling->GetSendTimelinePacket();
96         sendTimelinePacket->SendTimelineEntityBinaryPacket(4256);
97         sendTimelinePacket->Commit();
98         return true;
99     }
100 
EnableTimelineReporting(bool isEnabled)101     bool EnableTimelineReporting(bool isEnabled)
102     {
103         m_IsTimelineEnabled = isEnabled;
104         return isEnabled;
105     }
106 
TimelineReportingEnabled()107     bool TimelineReportingEnabled()
108     {
109         return m_IsTimelineEnabled;
110     }
111 
112 private:
113     IBackendInternal::IBackendProfilingPtr m_BackendProfiling;
114     uint32_t m_CapturePeriod;
115     std::vector<uint16_t> m_ActiveCounters;
116     bool m_IsTimelineEnabled;
117 };
118 
119 class MockBackendProfilingService
120 {
121 public:
122     // Getter for the singleton instance
Instance()123     static MockBackendProfilingService& Instance()
124     {
125         static MockBackendProfilingService instance;
126         return instance;
127     }
128 
GetContext()129     MockBackendProfilingContext* GetContext()
130     {
131         return m_sharedContext.get();
132     }
133 
SetProfilingContextPtr(std::shared_ptr<MockBackendProfilingContext> shared)134     void SetProfilingContextPtr(std::shared_ptr<MockBackendProfilingContext> shared)
135     {
136         m_sharedContext = shared;
137     }
138 
139 private:
140     std::shared_ptr<MockBackendProfilingContext> m_sharedContext;
141 };
142 
143 class MockBackend : public IBackendInternal
144 {
145 public:
146     MockBackend()  = default;
147     ~MockBackend() = default;
148 
149     static const BackendId& GetIdStatic();
GetId() const150     const BackendId& GetId() const override
151     {
152         return GetIdStatic();
153     }
154 
155     IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override;
156 
157     IBackendInternal::IWorkloadFactoryPtr
158         CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr) const override;
159 
160     IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override;
161     IBackendInternal::IBackendProfilingContextPtr
162         CreateBackendProfilingContext(const IRuntime::CreationOptions& creationOptions,
163                                       IBackendProfilingPtr& backendProfiling) override;
164 
165     IBackendInternal::Optimizations GetOptimizations() const override;
166     IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override;
167 
168     OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph) const override;
169 };
170 
171 class MockLayerSupport : public LayerSupportBase
172 {
173 public:
IsInputSupported(const TensorInfo &,Optional<std::string &>) const174     bool IsInputSupported(const TensorInfo& /*input*/,
175                           Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
176     {
177         return true;
178     }
179 
IsOutputSupported(const TensorInfo &,Optional<std::string &>) const180     bool IsOutputSupported(const TensorInfo& /*input*/,
181                            Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
182     {
183         return true;
184     }
185 
IsAdditionSupported(const TensorInfo &,const TensorInfo &,const TensorInfo &,Optional<std::string &>) const186     bool IsAdditionSupported(const TensorInfo& /*input0*/,
187                              const TensorInfo& /*input1*/,
188                              const TensorInfo& /*output*/,
189                              Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
190     {
191         return true;
192     }
193 
IsConvolution2dSupported(const TensorInfo &,const TensorInfo &,const Convolution2dDescriptor &,const TensorInfo &,const Optional<TensorInfo> &,Optional<std::string &>) const194     bool IsConvolution2dSupported(const TensorInfo& /*input*/,
195                                   const TensorInfo& /*output*/,
196                                   const Convolution2dDescriptor& /*descriptor*/,
197                                   const TensorInfo& /*weights*/,
198                                   const Optional<TensorInfo>& /*biases*/,
199                                   Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
200     {
201         return true;
202     }
203 };
204 
205 }    // namespace armnn
206