1 // 2 // Copyright © 2017, 2023 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include <armnn/Descriptors.hpp> 7 #include <armnn/IRuntime.hpp> 8 #include <armnn/INetwork.hpp> 9 10 #include <doctest/doctest.h> 11 12 #include <set> 13 14 TEST_SUITE("FlowControl") 15 { 16 TEST_CASE("ErrorOnLoadNetwork") 17 { 18 using namespace armnn; 19 20 // Create runtime in which test will run 21 IRuntime::CreationOptions options; 22 IRuntimePtr runtime(IRuntime::Create(options)); 23 24 // build up the structure of the network 25 // It's equivalent to something like 26 // if (0) {} else {} 27 INetworkPtr net(INetwork::Create()); 28 29 std::vector<uint8_t> falseData = {0}; 30 ConstTensor falseTensor(armnn::TensorInfo({1}, armnn::DataType::Boolean, 0.0f, 0, true), falseData); 31 IConnectableLayer* constLayer = net->AddConstantLayer(falseTensor, "const"); 32 constLayer->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({1}, armnn::DataType::Boolean)); 33 34 IConnectableLayer* input = net->AddInputLayer(0); 35 input->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({1}, armnn::DataType::Boolean)); 36 37 IConnectableLayer* switchLayer = net->AddSwitchLayer("switch"); 38 switchLayer->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({1}, armnn::DataType::Boolean)); 39 switchLayer->GetOutputSlot(1).SetTensorInfo(armnn::TensorInfo({1}, armnn::DataType::Boolean)); 40 41 IConnectableLayer* mergeLayer = net->AddMergeLayer("merge"); 42 mergeLayer->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({1}, armnn::DataType::Boolean)); 43 44 IConnectableLayer* output = net->AddOutputLayer(0); 45 46 input->GetOutputSlot(0).Connect(switchLayer->GetInputSlot(0)); 47 constLayer->GetOutputSlot(0).Connect(switchLayer->GetInputSlot(1)); 48 switchLayer->GetOutputSlot(0).Connect(mergeLayer->GetInputSlot(0)); 49 switchLayer->GetOutputSlot(1).Connect(mergeLayer->GetInputSlot(1)); 50 mergeLayer->GetOutputSlot(0).Connect(output->GetInputSlot(0)); 51 52 // optimize the network 53 std::vector<BackendId> backends = {Compute::CpuRef}; 54 std::vector<std::string> errMessages; 55 56 try 57 { 58 Optimize(*net, backends, runtime->GetDeviceSpec(), OptimizerOptionsOpaque(), errMessages); 59 FAIL("Should have thrown an exception."); 60 } 61 catch (const InvalidArgumentException&) 62 { 63 // Different exceptions are thrown on different backends 64 } 65 CHECK(errMessages.size() > 0); 66 } 67 68 } 69