1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include <armnn/ILayerVisitor.hpp> 8 #include <armnn/LayerVisitorBase.hpp> 9 10 #include <armnnSerializer/ISerializer.hpp> 11 12 #include <unordered_map> 13 14 #include <ArmnnSchema_generated.h> 15 16 #include <armnn/Types.hpp> 17 18 namespace armnnSerializer 19 { 20 21 class SerializerVisitor : public armnn::ILayerVisitor 22 { 23 public: SerializerVisitor()24 SerializerVisitor() : m_layerId(0) {} ~SerializerVisitor()25 ~SerializerVisitor() {} 26 GetFlatBufferBuilder()27 flatbuffers::FlatBufferBuilder& GetFlatBufferBuilder() 28 { 29 return m_flatBufferBuilder; 30 } 31 GetInputIds()32 std::vector<int>& GetInputIds() 33 { 34 return m_inputIds; 35 } 36 GetOutputIds()37 std::vector<int>& GetOutputIds() 38 { 39 return m_outputIds; 40 } 41 GetSerializedLayers()42 std::vector<flatbuffers::Offset<armnnSerializer::AnyLayer>>& GetSerializedLayers() 43 { 44 return m_serializedLayers; 45 } 46 47 flatbuffers::Offset<armnnSerializer::FeatureCompatibilityVersions> GetVersionTable(); 48 49 50 ARMNN_DEPRECATED_MSG("Use VisitElementwiseUnaryLayer instead") 51 void VisitAbsLayer(const armnn::IConnectableLayer* layer, 52 const char* name = nullptr) override; 53 54 void VisitActivationLayer(const armnn::IConnectableLayer* layer, 55 const armnn::ActivationDescriptor& descriptor, 56 const char* name = nullptr) override; 57 58 void VisitAdditionLayer(const armnn::IConnectableLayer* layer, 59 const char* name = nullptr) override; 60 61 void VisitArgMinMaxLayer(const armnn::IConnectableLayer* layer, 62 const armnn::ArgMinMaxDescriptor& argMinMaxDescriptor, 63 const char* name = nullptr) override; 64 65 void VisitBatchToSpaceNdLayer(const armnn::IConnectableLayer* layer, 66 const armnn::BatchToSpaceNdDescriptor& descriptor, 67 const char* name = nullptr) override; 68 69 void VisitBatchNormalizationLayer(const armnn::IConnectableLayer* layer, 70 const armnn::BatchNormalizationDescriptor& BatchNormalizationDescriptor, 71 const armnn::ConstTensor& mean, 72 const armnn::ConstTensor& variance, 73 const armnn::ConstTensor& beta, 74 const armnn::ConstTensor& gamma, 75 const char* name = nullptr) override; 76 77 void VisitComparisonLayer(const armnn::IConnectableLayer* layer, 78 const armnn::ComparisonDescriptor& descriptor, 79 const char* name = nullptr) override; 80 81 void VisitConcatLayer(const armnn::IConnectableLayer* layer, 82 const armnn::ConcatDescriptor& concatDescriptor, 83 const char* name = nullptr) override; 84 85 void VisitConstantLayer(const armnn::IConnectableLayer* layer, 86 const armnn::ConstTensor& input, 87 const char* = nullptr) override; 88 89 void VisitConvolution2dLayer(const armnn::IConnectableLayer* layer, 90 const armnn::Convolution2dDescriptor& descriptor, 91 const armnn::ConstTensor& weights, 92 const armnn::Optional<armnn::ConstTensor>& biases, 93 const char* = nullptr) override; 94 95 void VisitDepthToSpaceLayer(const armnn::IConnectableLayer* layer, 96 const armnn::DepthToSpaceDescriptor& descriptor, 97 const char* name = nullptr) override; 98 99 void VisitDepthwiseConvolution2dLayer(const armnn::IConnectableLayer* layer, 100 const armnn::DepthwiseConvolution2dDescriptor& descriptor, 101 const armnn::ConstTensor& weights, 102 const armnn::Optional<armnn::ConstTensor>& biases, 103 const char* name = nullptr) override; 104 105 void VisitDequantizeLayer(const armnn::IConnectableLayer* layer, 106 const char* name = nullptr) override; 107 108 void VisitDetectionPostProcessLayer(const armnn::IConnectableLayer* layer, 109 const armnn::DetectionPostProcessDescriptor& descriptor, 110 const armnn::ConstTensor& anchors, 111 const char* name = nullptr) override; 112 113 void VisitDivisionLayer(const armnn::IConnectableLayer* layer, 114 const char* name = nullptr) override; 115 116 void VisitElementwiseUnaryLayer(const armnn::IConnectableLayer* layer, 117 const armnn::ElementwiseUnaryDescriptor& descriptor, 118 const char* name = nullptr) override; 119 120 ARMNN_DEPRECATED_MSG("Use VisitComparisonLayer instead") 121 void VisitEqualLayer(const armnn::IConnectableLayer* layer, 122 const char* name = nullptr) override; 123 124 void VisitFillLayer(const armnn::IConnectableLayer* layer, 125 const armnn::FillDescriptor& fillDescriptor, 126 const char* name = nullptr) override; 127 128 void VisitFloorLayer(const armnn::IConnectableLayer *layer, 129 const char *name = nullptr) override; 130 131 void VisitFullyConnectedLayer(const armnn::IConnectableLayer* layer, 132 const armnn::FullyConnectedDescriptor& fullyConnectedDescriptor, 133 const armnn::ConstTensor& weights, 134 const armnn::Optional<armnn::ConstTensor>& biases, 135 const char* name = nullptr) override; 136 137 ARMNN_DEPRECATED_MSG("Use VisitGatherLayer with descriptor instead") 138 void VisitGatherLayer(const armnn::IConnectableLayer* layer, 139 const char* name = nullptr) override; 140 141 void VisitGatherLayer(const armnn::IConnectableLayer* layer, 142 const armnn::GatherDescriptor& gatherDescriptor, 143 const char* name = nullptr) override; 144 145 ARMNN_DEPRECATED_MSG("Use VisitComparisonLayer instead") 146 void VisitGreaterLayer(const armnn::IConnectableLayer* layer, 147 const char* name = nullptr) override; 148 149 void VisitInputLayer(const armnn::IConnectableLayer* layer, 150 armnn::LayerBindingId id, 151 const char* name = nullptr) override; 152 153 void VisitInstanceNormalizationLayer(const armnn::IConnectableLayer* layer, 154 const armnn::InstanceNormalizationDescriptor& instanceNormalizationDescriptor, 155 const char* name = nullptr) override; 156 157 void VisitL2NormalizationLayer(const armnn::IConnectableLayer* layer, 158 const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor, 159 const char* name = nullptr) override; 160 161 void VisitLogicalBinaryLayer(const armnn::IConnectableLayer* layer, 162 const armnn::LogicalBinaryDescriptor& descriptor, 163 const char* name = nullptr) override; 164 165 void VisitLogSoftmaxLayer(const armnn::IConnectableLayer* layer, 166 const armnn::LogSoftmaxDescriptor& logSoftmaxDescriptor, 167 const char* name = nullptr) override; 168 169 void VisitLstmLayer(const armnn::IConnectableLayer* layer, 170 const armnn::LstmDescriptor& descriptor, 171 const armnn::LstmInputParams& params, 172 const char* name = nullptr) override; 173 174 void VisitMeanLayer(const armnn::IConnectableLayer* layer, 175 const armnn::MeanDescriptor& descriptor, 176 const char* name) override; 177 178 void VisitMinimumLayer(const armnn::IConnectableLayer* layer, 179 const char* name = nullptr) override; 180 181 void VisitMaximumLayer(const armnn::IConnectableLayer* layer, 182 const char* name = nullptr) override; 183 184 void VisitMergeLayer(const armnn::IConnectableLayer* layer, 185 const char* name = nullptr) override; 186 187 ARMNN_DEPRECATED_MSG("Use VisitConcatLayer instead") 188 void VisitMergerLayer(const armnn::IConnectableLayer* layer, 189 const armnn::MergerDescriptor& mergerDescriptor, 190 const char* name = nullptr) override; 191 192 void VisitMultiplicationLayer(const armnn::IConnectableLayer* layer, 193 const char* name = nullptr) override; 194 195 void VisitOutputLayer(const armnn::IConnectableLayer* layer, 196 armnn::LayerBindingId id, 197 const char* name = nullptr) override; 198 199 void VisitPadLayer(const armnn::IConnectableLayer* layer, 200 const armnn::PadDescriptor& PadDescriptor, 201 const char* name = nullptr) override; 202 203 void VisitPermuteLayer(const armnn::IConnectableLayer* layer, 204 const armnn::PermuteDescriptor& PermuteDescriptor, 205 const char* name = nullptr) override; 206 207 void VisitPooling2dLayer(const armnn::IConnectableLayer* layer, 208 const armnn::Pooling2dDescriptor& pooling2dDescriptor, 209 const char* name = nullptr) override; 210 211 void VisitPreluLayer(const armnn::IConnectableLayer* layer, 212 const char* name = nullptr) override; 213 214 void VisitQuantizeLayer(const armnn::IConnectableLayer* layer, 215 const char* name = nullptr) override; 216 217 void VisitQLstmLayer(const armnn::IConnectableLayer* layer, 218 const armnn::QLstmDescriptor& descriptor, 219 const armnn::LstmInputParams& params, 220 const char* name = nullptr) override; 221 222 void VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer, 223 const armnn::QuantizedLstmInputParams& params, 224 const char* name = nullptr) override; 225 226 void VisitRankLayer(const armnn::IConnectableLayer* layer, 227 const char* name = nullptr) override; 228 229 void VisitReshapeLayer(const armnn::IConnectableLayer* layer, 230 const armnn::ReshapeDescriptor& reshapeDescriptor, 231 const char* name = nullptr) override; 232 233 void VisitResizeLayer(const armnn::IConnectableLayer* layer, 234 const armnn::ResizeDescriptor& resizeDescriptor, 235 const char* name = nullptr) override; 236 237 ARMNN_DEPRECATED_MSG("Use VisitResizeLayer instead") 238 void VisitResizeBilinearLayer(const armnn::IConnectableLayer* layer, 239 const armnn::ResizeBilinearDescriptor& resizeDescriptor, 240 const char* name = nullptr) override; 241 242 ARMNN_DEPRECATED_MSG("Use VisitElementwiseUnaryLayer instead") 243 void VisitRsqrtLayer(const armnn::IConnectableLayer* layer, 244 const char* name = nullptr) override; 245 246 void VisitSliceLayer(const armnn::IConnectableLayer* layer, 247 const armnn::SliceDescriptor& sliceDescriptor, 248 const char* name = nullptr) override; 249 250 void VisitSoftmaxLayer(const armnn::IConnectableLayer* layer, 251 const armnn::SoftmaxDescriptor& softmaxDescriptor, 252 const char* name = nullptr) override; 253 254 void VisitSpaceToBatchNdLayer(const armnn::IConnectableLayer* layer, 255 const armnn::SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor, 256 const char* name = nullptr) override; 257 258 void VisitSpaceToDepthLayer(const armnn::IConnectableLayer* layer, 259 const armnn::SpaceToDepthDescriptor& spaceToDepthDescriptor, 260 const char* name = nullptr) override; 261 262 void VisitNormalizationLayer(const armnn::IConnectableLayer* layer, 263 const armnn::NormalizationDescriptor& normalizationDescriptor, 264 const char* name = nullptr) override; 265 266 void VisitSplitterLayer(const armnn::IConnectableLayer* layer, 267 const armnn::ViewsDescriptor& viewsDescriptor, 268 const char* name = nullptr) override; 269 270 void VisitStandInLayer(const armnn::IConnectableLayer* layer, 271 const armnn::StandInDescriptor& standInDescriptor, 272 const char* name = nullptr) override; 273 274 void VisitStackLayer(const armnn::IConnectableLayer* layer, 275 const armnn::StackDescriptor& stackDescriptor, 276 const char* name = nullptr) override; 277 278 void VisitStridedSliceLayer(const armnn::IConnectableLayer* layer, 279 const armnn::StridedSliceDescriptor& stridedSliceDescriptor, 280 const char* name = nullptr) override; 281 282 void VisitSubtractionLayer(const armnn::IConnectableLayer* layer, 283 const char* name = nullptr) override; 284 285 void VisitSwitchLayer(const armnn::IConnectableLayer* layer, 286 const char* name = nullptr) override; 287 288 void VisitTransposeConvolution2dLayer(const armnn::IConnectableLayer* layer, 289 const armnn::TransposeConvolution2dDescriptor& descriptor, 290 const armnn::ConstTensor& weights, 291 const armnn::Optional<armnn::ConstTensor>& biases, 292 const char* = nullptr) override; 293 294 void VisitTransposeLayer(const armnn::IConnectableLayer* layer, 295 const armnn::TransposeDescriptor& descriptor, 296 const char* name = nullptr) override; 297 298 private: 299 300 /// Creates the Input Slots and Output Slots and LayerBase for the layer. 301 flatbuffers::Offset<armnnSerializer::LayerBase> CreateLayerBase( 302 const armnn::IConnectableLayer* layer, 303 const armnnSerializer::LayerType layerType); 304 305 /// Creates the serializer AnyLayer for the layer and adds it to m_serializedLayers. 306 void CreateAnyLayer(const flatbuffers::Offset<void>& layer, const armnnSerializer::Layer serializerLayer); 307 308 /// Creates the serializer ConstTensor for the armnn ConstTensor. 309 flatbuffers::Offset<armnnSerializer::ConstTensor> CreateConstTensorInfo( 310 const armnn::ConstTensor& constTensor); 311 312 /// Creates the serializer TensorInfo for the armnn TensorInfo. 313 flatbuffers::Offset<TensorInfo> CreateTensorInfo(const armnn::TensorInfo& tensorInfo); 314 315 template <typename T> 316 flatbuffers::Offset<flatbuffers::Vector<T>> CreateDataVector(const void* memory, unsigned int size); 317 318 ///Function which maps Guid to an index 319 uint32_t GetSerializedId(armnn::LayerGuid guid); 320 321 /// Creates the serializer InputSlots for the layer. 322 std::vector<flatbuffers::Offset<armnnSerializer::InputSlot>> CreateInputSlots( 323 const armnn::IConnectableLayer* layer); 324 325 /// Creates the serializer OutputSlots for the layer. 326 std::vector<flatbuffers::Offset<armnnSerializer::OutputSlot>> CreateOutputSlots( 327 const armnn::IConnectableLayer* layer); 328 329 /// FlatBufferBuilder to create our layers' FlatBuffers. 330 flatbuffers::FlatBufferBuilder m_flatBufferBuilder; 331 332 /// AnyLayers required by the SerializedGraph. 333 std::vector<flatbuffers::Offset<armnnSerializer::AnyLayer>> m_serializedLayers; 334 335 /// Vector of the binding ids of all Input Layers required by the SerializedGraph. 336 std::vector<int> m_inputIds; 337 338 /// Vector of the binding ids of all Output Layers required by the SerializedGraph. 339 std::vector<int> m_outputIds; 340 341 /// Mapped Guids of all Layers to match our index. 342 std::unordered_map<armnn::LayerGuid, uint32_t > m_guidMap; 343 344 /// layer within our FlatBuffer index. 345 uint32_t m_layerId; 346 }; 347 348 class Serializer : public ISerializer 349 { 350 public: Serializer()351 Serializer() {} ~Serializer()352 ~Serializer() {} 353 354 /// Serializes the network to ArmNN SerializedGraph. 355 /// @param [in] inNetwork The network to be serialized. 356 void Serialize(const armnn::INetwork& inNetwork) override; 357 358 /// Serializes the SerializedGraph to the stream. 359 /// @param [stream] the stream to save to 360 /// @return true if graph is Serialized to the Stream, false otherwise 361 bool SaveSerializedToStream(std::ostream& stream) override; 362 363 private: 364 365 /// Visitor to contruct serialized network 366 SerializerVisitor m_SerializerVisitor; 367 }; 368 369 } //namespace armnnSerializer 370