/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "FibonacciDriver"

#include "FibonacciDriver.h"

#include "HalInterfaces.h"
#include "NeuralNetworksExtensions.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Utils.h"
#include "ValidateHal.h"

#include "FibonacciExtension.h"

namespace android {
namespace nn {
namespace sample_driver {
namespace {

const uint8_t kLowBitsType = static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE);
const uint32_t kTypeWithinExtensionMask = (1 << kLowBitsType) - 1;
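// Extension operand and operation types are encoded as
// (prefix << kLowBitsType) | typeWithinExtension, where the prefix is assigned to the
// extension by the runtime, so the mask above recovers the type within the extension.
// As an illustration (values assumed, not prescribed): with LOW_BITS_TYPE == 16 and a
// runtime-assigned prefix of 0x0001, an extension type value of 0 encodes to 0x00010000.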

namespace fibonacci_op {

constexpr char kOperationName[] = "TEST_VENDOR_FIBONACCI";

constexpr uint32_t kNumInputs = 1;
constexpr uint32_t kInputN = 0;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

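// Looks up the numeric prefix that the runtime assigned to this extension in the given
// model. Fails unless the Fibonacci extension is the only extension named by the model.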
bool getFibonacciExtensionPrefix(const Model& model, uint16_t* prefix) {
    NN_RET_CHECK_EQ(model.extensionNameToPrefix.size(), 1u);  // Assumes no other extensions in use.
    NN_RET_CHECK_EQ(model.extensionNameToPrefix[0].name, TEST_VENDOR_FIBONACCI_EXTENSION_NAME);
    *prefix = model.extensionNameToPrefix[0].prefix;
    return true;
}

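// Returns true if the operation's type decodes to TEST_VENDOR_FIBONACCI under this
// extension's prefix.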
bool isFibonacciOperation(const Operation& operation, const Model& model) {
    int32_t operationType = static_cast<int32_t>(operation.type);
    uint16_t prefix;
    NN_RET_CHECK(getFibonacciExtensionPrefix(model, &prefix));
    NN_RET_CHECK_EQ(operationType, (prefix << kLowBitsType) | TEST_VENDOR_FIBONACCI);
    return true;
}

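// Checks the operation signature: one input that is either a TEST_VENDOR_INT64 scalar or
// a TENSOR_FLOAT32 tensor, and one output that is either a
// TEST_VENDOR_TENSOR_QUANT64_ASYMM tensor or a TENSOR_FLOAT32 tensor.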
bool validate(const Operation& operation, const Model& model) {
    NN_RET_CHECK(isFibonacciOperation(operation, model));
    NN_RET_CHECK_EQ(operation.inputs.size(), kNumInputs);
    NN_RET_CHECK_EQ(operation.outputs.size(), kNumOutputs);
    int32_t inputType = static_cast<int32_t>(model.operands[operation.inputs[0]].type);
    int32_t outputType = static_cast<int32_t>(model.operands[operation.outputs[0]].type);
    uint16_t prefix;
    NN_RET_CHECK(getFibonacciExtensionPrefix(model, &prefix));
    NN_RET_CHECK(inputType == ((prefix << kLowBitsType) | TEST_VENDOR_INT64) ||
                 inputType == ANEURALNETWORKS_TENSOR_FLOAT32);
    NN_RET_CHECK(outputType == ((prefix << kLowBitsType) | TEST_VENDOR_TENSOR_QUANT64_ASYMM) ||
                 outputType == ANEURALNETWORKS_TENSOR_FLOAT32);
    return true;
}

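// Reads the requested sequence length n from either input representation, validates
// n >= 1, and resizes the output to a rank-1 tensor of n elements.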
bool prepare(IOperationExecutionContext* context) {
    int64_t n;
    if (context->getInputType(kInputN) == OperandType::TENSOR_FLOAT32) {
        n = static_cast<int64_t>(context->getInputValue<float>(kInputN));
    } else {
        n = context->getInputValue<int64_t>(kInputN);
    }
    NN_RET_CHECK_GE(n, 1);
    Shape output = context->getOutputShape(kOutputTensor);
    output.dimensions = {static_cast<uint32_t>(n)};
    return context->setOutputShape(kOutputTensor, output);
}

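// Computes the first n Fibonacci numbers into `output`, then quantizes them in place as
// output[i] = output[i] / outputScale + outputZeroPoint. With outputScale == 1.0 and
// outputZeroPoint == 0 the quantization step is the identity, which is how the float
// path in execute() below uses it. As a worked example with assumed parameters, a scale
// of 0.5 and a zero point of 10 would store the value 5 as 5 / 0.5 + 10 = 20.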
template <typename ScaleT, typename ZeroPointT, typename OutputT>
bool compute(int32_t n, ScaleT outputScale, ZeroPointT outputZeroPoint, OutputT* output) {
    // Compute the Fibonacci numbers.
    if (n >= 1) {
        output[0] = 1;
    }
    if (n >= 2) {
        output[1] = 1;
    }
    if (n >= 3) {
        for (int32_t i = 2; i < n; ++i) {
            output[i] = output[i - 1] + output[i - 2];
        }
    }

    // Quantize output.
    for (int32_t i = 0; i < n; ++i) {
        output[i] = output[i] / outputScale + outputZeroPoint;
    }

    return true;
}

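// Reads n, then dispatches on the output operand type: the float path uses an identity
// quantization, while the quantized path takes the scale and zero point from the output
// operand's extension parameters.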
bool execute(IOperationExecutionContext* context) {
    int64_t n;
    if (context->getInputType(kInputN) == OperandType::TENSOR_FLOAT32) {
        n = static_cast<int64_t>(context->getInputValue<float>(kInputN));
    } else {
        n = context->getInputValue<int64_t>(kInputN);
    }
    if (context->getOutputType(kOutputTensor) == OperandType::TENSOR_FLOAT32) {
        float* output = context->getOutputBuffer<float>(kOutputTensor);
        return compute(n, /*scale=*/1.0, /*zeroPoint=*/0, output);
    } else {
        uint64_t* output = context->getOutputBuffer<uint64_t>(kOutputTensor);
        Shape outputShape = context->getOutputShape(kOutputTensor);
        auto outputQuant = reinterpret_cast<const TestVendorQuant64AsymmParams*>(
                outputShape.extraParams.extension().data());
        return compute(n, outputQuant->scale, outputQuant->zeroPoint, output);
    }
}

}  // namespace fibonacci_op
}  // namespace

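// Resolves an operation type to its registration. Since this sample assumes the
// Fibonacci extension is the only extension in use, any type with a nonzero prefix whose
// type-within-extension bits match TEST_VENDOR_FIBONACCI is treated as ours.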
const OperationRegistration* FibonacciOperationResolver::findOperation(
        OperationType operationType) const {
    // .validate is omitted because it's not used by the extension driver.
    static OperationRegistration operationRegistration(operationType, fibonacci_op::kOperationName,
                                                       nullptr, fibonacci_op::prepare,
                                                       fibonacci_op::execute, {});
    uint16_t prefix = static_cast<int32_t>(operationType) >> kLowBitsType;
    uint16_t typeWithinExtension = static_cast<int32_t>(operationType) & kTypeWithinExtensionMask;
    // Assumes no other extensions in use.
    return prefix != 0 && typeWithinExtension == TEST_VENDOR_FIBONACCI ? &operationRegistration
                                                                       : nullptr;
}

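// Advertises the extension to the runtime, declaring the two operand types it adds: an
// 8-byte scalar integer (TEST_VENDOR_INT64) and an asymmetric quantized tensor with
// 8-byte elements (TEST_VENDOR_TENSOR_QUANT64_ASYMM).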
Return<void> FibonacciDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
    cb(ErrorStatus::NONE,
       {
               {
                       .name = TEST_VENDOR_FIBONACCI_EXTENSION_NAME,
                       .operandTypes =
                               {
                                       {
                                               .type = TEST_VENDOR_INT64,
                                               .isTensor = false,
                                               .byteSize = 8,
                                       },
                                       {
                                               .type = TEST_VENDOR_TENSOR_QUANT64_ASYMM,
                                               .isTensor = true,
                                               .byteSize = 8,
                                       },
                               },
               },
       });
    return Void();
}

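// Reports uniform capabilities: the same nominal execTime and powerUsage for every
// operand type, since this sample driver makes no real performance claims.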
Return<void> FibonacciDriver::getCapabilities_1_2(getCapabilities_1_2_cb cb) {
    android::nn::initVLogMask();
    VLOG(DRIVER) << "getCapabilities()";
    static const PerformanceInfo kPerf = {.execTime = 1.0f, .powerUsage = 1.0f};
    Capabilities capabilities = {.relaxedFloat32toFloat16PerformanceScalar = kPerf,
                                 .relaxedFloat32toFloat16PerformanceTensor = kPerf,
                                 .operandPerformance = nonExtensionOperandPerformance(kPerf)};
    cb(ErrorStatus::NONE, capabilities);
    return Void();
}

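// Marks as supported exactly those operations that are this extension's Fibonacci
// operation and pass validation; all other operations in the model are reported as
// unsupported.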
Return<void> FibonacciDriver::getSupportedOperations_1_2(const V1_2::Model& model,
                                                         getSupportedOperations_1_2_cb cb) {
    VLOG(DRIVER) << "getSupportedOperations()";
    if (!validateModel(model)) {
        cb(ErrorStatus::INVALID_ARGUMENT, {});
        return Void();
    }
    const size_t count = model.operations.size();
    std::vector<bool> supported(count);
    for (size_t i = 0; i < count; ++i) {
        const Operation& operation = model.operations[i];
        if (fibonacci_op::isFibonacciOperation(operation, model)) {
            if (!fibonacci_op::validate(operation, model)) {
                cb(ErrorStatus::INVALID_ARGUMENT, {});
                return Void();
            }
            supported[i] = true;
        }
    }
    cb(ErrorStatus::NONE, supported);
    return Void();
}

}  // namespace sample_driver
}  // namespace nn
}  // namespace android