• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <armnn/INetwork.hpp>
#include <armnn/IRuntime.hpp>
#include <armnn/Utils.hpp>
#include <armnn/Descriptors.hpp>

#include <iostream>
#include <utility>
#include <vector>

12 /// A simple example of using the ArmNN SDK API. In this sample, the users single input number is multiplied by 1.0f
13 /// using a fully connected layer with a single neuron to produce an output number that is the same as the input.
main()14 int main()
15 {
16     using namespace armnn;
17 
18     float number;
19     std::cout << "Please enter a number: " << std::endl;
20     std::cin >> number;
21 
22     // Turn on logging to standard output
23     // This is useful in this sample so that users can learn more about what is going on
24     ConfigureLogging(true, false, LogSeverity::Warning);
25 
26     // Construct ArmNN network
27     NetworkId networkIdentifier;
28     INetworkPtr myNetwork = INetwork::Create();
29 
30     float weightsData[] = {1.0f}; // Identity
31     TensorInfo weightsInfo(TensorShape({1, 1}), DataType::Float32, 0.0f, 0, true);
32     weightsInfo.SetConstant();
33     ConstTensor weights(weightsInfo, weightsData);
34 
35     // Constant layer that now holds weights data for FullyConnected
36     IConnectableLayer* const constantWeightsLayer = myNetwork->AddConstantLayer(weights, "const weights");
37 
38     FullyConnectedDescriptor fullyConnectedDesc;
39     IConnectableLayer* const fullyConnectedLayer = myNetwork->AddFullyConnectedLayer(fullyConnectedDesc,
40                                                                                      "fully connected");
41     IConnectableLayer* InputLayer  = myNetwork->AddInputLayer(0);
42     IConnectableLayer* OutputLayer = myNetwork->AddOutputLayer(0);
43 
44     InputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(0));
45     constantWeightsLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(1));
46     fullyConnectedLayer->GetOutputSlot(0).Connect(OutputLayer->GetInputSlot(0));
47 
48     // Create ArmNN runtime
49     IRuntime::CreationOptions options; // default options
50     IRuntimePtr run = IRuntime::Create(options);
51 
52     //Set the tensors in the network.
53     TensorInfo inputTensorInfo(TensorShape({1, 1}), DataType::Float32);
54     InputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
55 
56     TensorInfo outputTensorInfo(TensorShape({1, 1}), DataType::Float32);
57     fullyConnectedLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
58     constantWeightsLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo);
59 
60     // Optimise ArmNN network
61     IOptimizedNetworkPtr optNet = Optimize(*myNetwork, {Compute::CpuRef}, run->GetDeviceSpec());
62     if (!optNet)
63     {
64         // This shouldn't happen for this simple sample, with reference backend.
65         // But in general usage Optimize could fail if the hardware at runtime cannot
66         // support the model that has been provided.
67         std::cerr << "Error: Failed to optimise the input network." << std::endl;
68         return 1;
69     }
70 
71     // Load graph into runtime
72     run->LoadNetwork(networkIdentifier, std::move(optNet));
73 
74     //Creates structures for inputs and outputs.
75     std::vector<float> inputData{number};
76     std::vector<float> outputData(1);
77 
78     inputTensorInfo = run->GetInputTensorInfo(networkIdentifier, 0);
79     inputTensorInfo.SetConstant(true);
80     InputTensors inputTensors{{0, armnn::ConstTensor(inputTensorInfo,
81                                                      inputData.data())}};
82     OutputTensors outputTensors{{0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0),
83                                                   outputData.data())}};
84 
85     // Execute network
86     run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
87 
88     std::cout << "Your number was " << outputData[0] << std::endl;
89     return 0;
90 
91 }
92