• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <SampleDriverPartial.h>
18 #include <gtest/gtest.h>
19 
20 #include <algorithm>
21 #include <memory>
22 #include <vector>
23 
24 #include "CompilationBuilder.h"
25 #include "ExecutionPlan.h"
26 #include "HalUtils.h"
27 #include "Manager.h"
28 #include "TestNeuralNetworksWrapper.h"
29 
30 namespace android::nn {
31 namespace {
32 
33 using sample_driver::SampleDriverPartial;
34 using Result = test_wrapper::Result;
35 using WrapperOperandType = test_wrapper::OperandType;
36 using WrapperCompilation = test_wrapper::Compilation;
37 using WrapperExecution = test_wrapper::Execution;
38 using WrapperType = test_wrapper::Type;
39 using WrapperModel = test_wrapper::Model;
40 
41 class EmptyOperationResolver : public IOperationResolver {
42    public:
findOperation(OperationType) const43     const OperationRegistration* findOperation(OperationType) const override { return nullptr; }
44 };
45 
// Name under which FailingTestDriver is constructed and registered as a device. It is a
// constexpr array rather than a mutable `const char*` so the name cannot be reassigned at
// runtime. It still decays to `const char*` (and converts to std::string) at its use sites.
constexpr char kTestDriverName[] = "nnapi-test-sqrt-failing";
47 
48 // A driver that only supports SQRT and fails during execution.
49 class FailingTestDriver : public SampleDriverPartial {
50    public:
51     // EmptyOperationResolver causes execution to fail.
FailingTestDriver()52     FailingTestDriver() : SampleDriverPartial(kTestDriverName, &mEmptyOperationResolver) {}
53 
getCapabilities_1_3(getCapabilities_1_3_cb cb)54     hardware::Return<void> getCapabilities_1_3(getCapabilities_1_3_cb cb) override {
55         cb(V1_3::ErrorStatus::NONE, makeCapabilities(0.1));  // Faster than CPU.
56         return hardware::Void();
57     }
58 
59    private:
getSupportedOperationsImpl(const V1_3::Model & model) const60     std::vector<bool> getSupportedOperationsImpl(const V1_3::Model& model) const override {
61         std::vector<bool> supported(model.main.operations.size());
62         std::transform(model.main.operations.begin(), model.main.operations.end(),
63                        supported.begin(), [](const V1_3::Operation& operation) {
64                            return operation.type == V1_3::OperationType::SQRT;
65                        });
66         return supported;
67     }
68 
69     const EmptyOperationResolver mEmptyOperationResolver;
70 };
71 
72 class FailingDriverTest : public ::testing::Test {
SetUp()73     virtual void SetUp() {
74         DeviceManager* deviceManager = DeviceManager::get();
75         if (deviceManager->getUseCpuOnly() ||
76             !DeviceManager::partitioningAllowsFallback(deviceManager->getPartitioning())) {
77             GTEST_SKIP();
78         }
79         mTestDevice = DeviceManager::forTest_makeDriverDevice(
80                 makeSharedDevice(kTestDriverName, new FailingTestDriver()));
81         deviceManager->forTest_setDevices({
82                 mTestDevice,
83                 DeviceManager::getCpuDevice(),
84         });
85     }
86 
TearDown()87     virtual void TearDown() { DeviceManager::get()->forTest_reInitializeDeviceList(); }
88 
89    protected:
90     std::shared_ptr<Device> mTestDevice;
91 };
92 
93 // Regression test for b/152623150.
TEST_F(FailingDriverTest,FailAfterInterpretedWhile)94 TEST_F(FailingDriverTest, FailAfterInterpretedWhile) {
95     // Model:
96     //     f = input0
97     //     b = input1
98     //     while CAST(b):  # Identity cast.
99     //         f = CAST(f)
100     //     # FailingTestDriver fails here. When partial CPU fallback happens,
101     //     # it should not loop forever.
102     //     output0 = SQRT(f)
103 
104     WrapperOperandType floatType(WrapperType::TENSOR_FLOAT32, {2});
105     WrapperOperandType boolType(WrapperType::TENSOR_BOOL8, {1});
106 
107     WrapperModel conditionModel;
108     {
109         uint32_t f = conditionModel.addOperand(&floatType);
110         uint32_t b = conditionModel.addOperand(&boolType);
111         uint32_t out = conditionModel.addOperand(&boolType);
112         conditionModel.addOperation(ANEURALNETWORKS_CAST, {b}, {out});
113         conditionModel.identifyInputsAndOutputs({f, b}, {out});
114         ASSERT_EQ(conditionModel.finish(), Result::NO_ERROR);
115         ASSERT_TRUE(conditionModel.isValid());
116     }
117 
118     WrapperModel bodyModel;
119     {
120         uint32_t f = bodyModel.addOperand(&floatType);
121         uint32_t b = bodyModel.addOperand(&boolType);
122         uint32_t out = bodyModel.addOperand(&floatType);
123         bodyModel.addOperation(ANEURALNETWORKS_CAST, {f}, {out});
124         bodyModel.identifyInputsAndOutputs({f, b}, {out});
125         ASSERT_EQ(bodyModel.finish(), Result::NO_ERROR);
126         ASSERT_TRUE(bodyModel.isValid());
127     }
128 
129     WrapperModel model;
130     {
131         uint32_t fInput = model.addOperand(&floatType);
132         uint32_t bInput = model.addOperand(&boolType);
133         uint32_t fTmp = model.addOperand(&floatType);
134         uint32_t fSqrt = model.addOperand(&floatType);
135         uint32_t cond = model.addModelOperand(&conditionModel);
136         uint32_t body = model.addModelOperand(&bodyModel);
137         model.addOperation(ANEURALNETWORKS_WHILE, {cond, body, fInput, bInput}, {fTmp});
138         model.addOperation(ANEURALNETWORKS_SQRT, {fTmp}, {fSqrt});
139         model.identifyInputsAndOutputs({fInput, bInput}, {fSqrt});
140         ASSERT_TRUE(model.isValid());
141         ASSERT_EQ(model.finish(), Result::NO_ERROR);
142     }
143 
144     WrapperCompilation compilation(&model);
145     ASSERT_EQ(compilation.finish(), Result::NO_ERROR);
146 
147     const CompilationBuilder* compilationBuilder =
148             reinterpret_cast<CompilationBuilder*>(compilation.getHandle());
149     const ExecutionPlan& plan = compilationBuilder->forTest_getExecutionPlan();
150     const std::vector<std::shared_ptr<LogicalStep>>& steps = plan.forTest_compoundGetSteps();
151     ASSERT_EQ(steps.size(), 6u);
152     ASSERT_TRUE(steps[0]->isWhile());
153     ASSERT_TRUE(steps[1]->isExecution());
154     ASSERT_EQ(steps[1]->executionStep()->getDevice(), DeviceManager::getCpuDevice());
155     ASSERT_TRUE(steps[2]->isGoto());
156     ASSERT_TRUE(steps[3]->isExecution());
157     ASSERT_EQ(steps[3]->executionStep()->getDevice(), DeviceManager::getCpuDevice());
158     ASSERT_TRUE(steps[4]->isGoto());
159     ASSERT_TRUE(steps[5]->isExecution());
160     ASSERT_EQ(steps[5]->executionStep()->getDevice(), mTestDevice);
161 
162     WrapperExecution execution(&compilation);
163     const float fInput[] = {12 * 12, 5 * 5};
164     const bool8 bInput = false;
165     float fSqrt[] = {0, 0};
166     ASSERT_EQ(execution.setInput(0, &fInput), Result::NO_ERROR);
167     ASSERT_EQ(execution.setInput(1, &bInput), Result::NO_ERROR);
168     ASSERT_EQ(execution.setOutput(0, &fSqrt), Result::NO_ERROR);
169     ASSERT_EQ(execution.compute(), Result::NO_ERROR);
170     ASSERT_EQ(fSqrt[0], 12);
171     ASSERT_EQ(fSqrt[1], 5);
172 }
173 
174 // Regression test for b/155923033.
TEST_F(FailingDriverTest,SimplePlan)175 TEST_F(FailingDriverTest, SimplePlan) {
176     // Model:
177     //     output0 = SQRT(input0)
178     //
179     // This results in a SIMPLE execution plan. When FailingTestDriver fails,
180     // partial CPU fallback should complete the execution.
181 
182     WrapperOperandType floatType(WrapperType::TENSOR_FLOAT32, {2});
183 
184     WrapperModel model;
185     {
186         uint32_t fInput = model.addOperand(&floatType);
187         uint32_t fSqrt = model.addOperand(&floatType);
188         model.addOperation(ANEURALNETWORKS_SQRT, {fInput}, {fSqrt});
189         model.identifyInputsAndOutputs({fInput}, {fSqrt});
190         ASSERT_TRUE(model.isValid());
191         ASSERT_EQ(model.finish(), Result::NO_ERROR);
192     }
193 
194     WrapperCompilation compilation(&model);
195     ASSERT_EQ(compilation.finish(), Result::NO_ERROR);
196 
197     const CompilationBuilder* compilationBuilder =
198             reinterpret_cast<CompilationBuilder*>(compilation.getHandle());
199     const ExecutionPlan& plan = compilationBuilder->forTest_getExecutionPlan();
200     ASSERT_TRUE(plan.isSimple());
201 
202     WrapperExecution execution(&compilation);
203     const float fInput[] = {12 * 12, 5 * 5};
204     float fSqrt[] = {0, 0};
205     ASSERT_EQ(execution.setInput(0, &fInput), Result::NO_ERROR);
206     ASSERT_EQ(execution.setOutput(0, &fSqrt), Result::NO_ERROR);
207     ASSERT_EQ(execution.compute(), Result::NO_ERROR);
208     ASSERT_EQ(fSqrt[0], 12);
209     ASSERT_EQ(fSqrt[1], 5);
210 }
211 
212 }  // namespace
213 }  // namespace android::nn
214