//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <TestUtils.hpp>

#include <Optimizer.hpp>

#include <doctest/doctest.h>

TEST_SUITE("Optimizer")
{
using namespace armnn::optimizations;

TEST_CASE("Fp32NetworkToFp16OptimizationTest")
{
    armnn::Graph graph;

    const armnn::TensorInfo infoFP32({ 2, 2, 1, 3 }, armnn::DataType::Float32);

    // Create the simple test network
    auto input = graph.AddLayer<armnn::InputLayer>(0, "input");
    input->GetOutputSlot().SetTensorInfo(infoFP32);

    auto floor = graph.AddLayer<armnn::FloorLayer>("floor");
    floor->GetOutputSlot().SetTensorInfo(infoFP32);

    auto output = graph.AddLayer<armnn::OutputLayer>(1, "output");

    // Connect up the layers
    input->GetOutputSlot().Connect(floor->GetInputSlot(0));
    floor->GetOutputSlot().Connect(output->GetInputSlot(0));

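    // Check the graph layout before optimization: Input -> Floor -> Output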
    CHECK(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>,
                                                      &IsLayerOfType<armnn::FloorLayer>,
                                                      &IsLayerOfType<armnn::OutputLayer>));

    // Run the optimizer
    armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(Fp32NetworkToFp16Converter()));

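    // The converter should have inserted a ConvertFp32ToFp16 layer after the input
    // and a ConvertFp16ToFp32 layer before the output, leaving the Floor layer in between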
    CHECK(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>,
                                                      &IsLayerOfType<armnn::ConvertFp32ToFp16Layer>,
                                                      &IsLayerOfType<armnn::FloorLayer>,
                                                      &IsLayerOfType<armnn::ConvertFp16ToFp32Layer>,
                                                      &IsLayerOfType<armnn::OutputLayer>));

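    // The Floor layer and the tensors on either side of it should now be Float16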
    CHECK(floor->GetDataType() == armnn::DataType::Float16);
    CHECK(floor->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo().GetDataType() == armnn::DataType::Float16);
    CHECK(floor->GetOutputSlot(0).GetTensorInfo().GetDataType() == armnn::DataType::Float16);
}

}