• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "HalInterfaces.h"
20 #include "OperationResolver.h"
21 #include "OperationsUtils.h"
22 #include "Tracing.h"
23 
24 #include <cmath>
25 
26 namespace android {
27 namespace nn {
28 namespace elementwise {
29 
// Operand layout shared by every unary elementwise operation in this file:
// exactly one input operand and one output operand, each at index 0.
constexpr uint32_t kNumInputs = 1;
constexpr uint32_t kNumOutputs = 1;

constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kOutputTensor = 0;
35 
36 namespace {
37 
38 template <typename T>
compute(float func (float),const T * input,const Shape & shape,T * output)39 inline bool compute(float func(float), const T* input, const Shape& shape, T* output) {
40     const auto size = getNumberOfElements(shape);
41     for (uint32_t i = 0; i < size; ++i) {
42         output[i] = static_cast<T>(func(static_cast<float>(input[i])));
43     }
44     return true;
45 }
46 
execute(IOperationExecutionContext * context,float func (float))47 bool execute(IOperationExecutionContext* context, float func(float)) {
48     switch (context->getInputType(kInputTensor)) {
49         case OperandType::TENSOR_FLOAT16:
50             return compute(func, context->getInputBuffer<_Float16>(kInputTensor),
51                            context->getInputShape(kInputTensor),
52                            context->getOutputBuffer<_Float16>(kOutputTensor));
53         case OperandType::TENSOR_FLOAT32:
54             return compute(func, context->getInputBuffer<float>(kInputTensor),
55                            context->getInputShape(kInputTensor),
56                            context->getOutputBuffer<float>(kOutputTensor));
57         default:
58             NN_RET_CHECK_FAIL() << "Unsupported tensor type for elementwise operation";
59     }
60 }
61 
62 }  // namespace
63 
// Shared validation for the unary elementwise operations registered at the
// bottom of this file (ABS, EXP, LOG, RSQRT, SIN, SQRT).
// Requirements checked, in order:
//   - exactly kNumInputs (1) input and kNumOutputs (1) output operand;
//   - the input is TENSOR_FLOAT16 or TENSOR_FLOAT32;
//   - input and output operand types both equal that input type.
// On success, returns validateHalVersion(context, HalVersion::V1_2)
// (presumably "requires HAL 1.2 or newer" -- confirm in OperationsUtils).
validate(const IOperationValidationContext * context)64 bool validate(const IOperationValidationContext* context) {
    // NOTE(review): the NN_RET_CHECK* macros presumably log and return false
    // from this function when the check fails -- verify in OperationsUtils.h.
65     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
66     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
67     OperandType inputType = context->getInputType(kInputTensor);
    // Only floating-point tensors are supported by these operations.
68     NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
69                  inputType == OperandType::TENSOR_FLOAT32)
70             << "Unsupported tensor type for elementwise operation";
    // The output must carry the same operand type as the input.
71     NN_RET_CHECK(validateInputTypes(context, {inputType}));
72     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
73     return validateHalVersion(context, HalVersion::V1_2);
74 }
75 
// Shared prepare step for the unary elementwise operations: derives the
// output operand's shape from the input operand's shape and publishes it.
// Returns false if SetShape fails, otherwise the result of setOutputShape.
prepare(IOperationExecutionContext * context)76 bool prepare(IOperationExecutionContext* context) {
77     Shape input = context->getInputShape(kInputTensor);
78     Shape output = context->getOutputShape(kOutputTensor);
    // Presumably copies the input's dimensions/type into `output` (elementwise
    // ops preserve shape) -- confirm SetShape semantics in OperationsUtils.
79     NN_RET_CHECK(SetShape(input, &output));
80     return context->setOutputShape(kOutputTensor, output);
81 }
82 
executeAbs(IOperationExecutionContext * context)83 bool executeAbs(IOperationExecutionContext* context) {
84     return execute(context, std::abs);
85 }
86 
executeExp(IOperationExecutionContext * context)87 bool executeExp(IOperationExecutionContext* context) {
88     return execute(context, std::exp);
89 }
90 
executeLog(IOperationExecutionContext * context)91 bool executeLog(IOperationExecutionContext* context) {
92     return execute(context, std::log);
93 }
94 
executeRsqrt(IOperationExecutionContext * context)95 bool executeRsqrt(IOperationExecutionContext* context) {
96     return execute(context, [](float x) { return 1.f / std::sqrt(x); });
97 }
98 
executeSin(IOperationExecutionContext * context)99 bool executeSin(IOperationExecutionContext* context) {
100     return execute(context, std::sin);
101 }
102 
executeSqrt(IOperationExecutionContext * context)103 bool executeSqrt(IOperationExecutionContext* context) {
104     return execute(context, std::sqrt);
105 }
106 
107 }  // namespace elementwise
108 
// Register the unary elementwise operations. All six share
// elementwise::validate and elementwise::prepare; they differ only in the
// execute callback, which selects the per-element float function.
109 NN_REGISTER_OPERATION(ABS, "ABS", elementwise::validate, elementwise::prepare,
110                       elementwise::executeAbs);
111 NN_REGISTER_OPERATION(EXP, "EXP", elementwise::validate, elementwise::prepare,
112                       elementwise::executeExp);
113 NN_REGISTER_OPERATION(LOG, "LOG", elementwise::validate, elementwise::prepare,
114                       elementwise::executeLog);
115 NN_REGISTER_OPERATION(RSQRT, "RSQRT", elementwise::validate, elementwise::prepare,
116                       elementwise::executeRsqrt);
117 NN_REGISTER_OPERATION(SIN, "SIN", elementwise::validate, elementwise::prepare,
118                       elementwise::executeSin);
119 NN_REGISTER_OPERATION(SQRT, "SQRT", elementwise::validate, elementwise::prepare,
120                       elementwise::executeSqrt);
121 
122 }  // namespace nn
123 }  // namespace android
124