1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include "HalInterfaces.h"
20 #include "OperationResolver.h"
21 #include "OperationsUtils.h"
22 #include "Tracing.h"
23
24 #include <cmath>
25
26 namespace android {
27 namespace nn {
28 namespace elementwise {
29
// Operand arity shared by every unary elementwise operation in this file:
// exactly one input tensor and one output tensor.
constexpr uint32_t kNumInputs{1};
constexpr uint32_t kNumOutputs{1};

// Operand indices of that single input and single output.
constexpr uint32_t kInputTensor{0};
constexpr uint32_t kOutputTensor{0};
35
36 namespace {
37
38 template <typename T>
compute(float func (float),const T * input,const Shape & shape,T * output)39 inline bool compute(float func(float), const T* input, const Shape& shape, T* output) {
40 const auto size = getNumberOfElements(shape);
41 for (uint32_t i = 0; i < size; ++i) {
42 output[i] = static_cast<T>(func(static_cast<float>(input[i])));
43 }
44 return true;
45 }
46
execute(IOperationExecutionContext * context,float func (float))47 bool execute(IOperationExecutionContext* context, float func(float)) {
48 switch (context->getInputType(kInputTensor)) {
49 case OperandType::TENSOR_FLOAT16:
50 return compute(func, context->getInputBuffer<_Float16>(kInputTensor),
51 context->getInputShape(kInputTensor),
52 context->getOutputBuffer<_Float16>(kOutputTensor));
53 case OperandType::TENSOR_FLOAT32:
54 return compute(func, context->getInputBuffer<float>(kInputTensor),
55 context->getInputShape(kInputTensor),
56 context->getOutputBuffer<float>(kOutputTensor));
57 default:
58 NN_RET_CHECK_FAIL() << "Unsupported tensor type for elementwise operation";
59 }
60 }
61
62 } // namespace
63
// Validates a unary elementwise operation (ABS, EXP, LOG, RSQRT, SIN or SQRT).
//
// Each NN_RET_CHECK* below returns early when its condition fails. Together they require:
//   - exactly kNumInputs inputs and kNumOutputs outputs;
//   - an input of type TENSOR_FLOAT16 or TENSOR_FLOAT32;
//   - input and output operand types equal to that input type, checked through
//     validateInputTypes / validateOutputTypes.
// The final result comes from validateHalVersion with HalVersion::V1_2.
// Presumably these operations need a HAL version of at least 1.2; confirm this
// against validateHalVersion.
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    // Only floating-point tensors are supported; execute() dispatches on these two types.
    NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
                 inputType == OperandType::TENSOR_FLOAT32)
            << "Unsupported tensor type for elementwise operation";
    NN_RET_CHECK(validateInputTypes(context, {inputType}));
    // The output must have the same operand type as the input.
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return validateHalVersion(context, HalVersion::V1_2);
}
75
// Prepares a unary elementwise operation by deriving the output shape from the
// input shape.
//
// Returns false if SetShape fails. Otherwise returns whatever setOutputShape
// reports for the derived shape.
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape output = context->getOutputShape(kOutputTensor);
    // Presumably SetShape copies the input's dimensions into `output`; confirm
    // against SetShape in OperationsUtils.
    NN_RET_CHECK(SetShape(input, &output));
    return context->setOutputShape(kOutputTensor, output);
}
82
executeAbs(IOperationExecutionContext * context)83 bool executeAbs(IOperationExecutionContext* context) {
84 return execute(context, std::abs);
85 }
86
executeExp(IOperationExecutionContext * context)87 bool executeExp(IOperationExecutionContext* context) {
88 return execute(context, std::exp);
89 }
90
executeLog(IOperationExecutionContext * context)91 bool executeLog(IOperationExecutionContext* context) {
92 return execute(context, std::log);
93 }
94
executeRsqrt(IOperationExecutionContext * context)95 bool executeRsqrt(IOperationExecutionContext* context) {
96 return execute(context, [](float x) { return 1.f / std::sqrt(x); });
97 }
98
executeSin(IOperationExecutionContext * context)99 bool executeSin(IOperationExecutionContext* context) {
100 return execute(context, std::sin);
101 }
102
executeSqrt(IOperationExecutionContext * context)103 bool executeSqrt(IOperationExecutionContext* context) {
104 return execute(context, std::sqrt);
105 }
106
107 } // namespace elementwise
108
// Registers the unary elementwise operations. All of them share
// elementwise::validate and elementwise::prepare. They differ only in the
// execute function, which chooses the per-element math function.
NN_REGISTER_OPERATION(ABS, "ABS", elementwise::validate, elementwise::prepare,
                      elementwise::executeAbs);
NN_REGISTER_OPERATION(EXP, "EXP", elementwise::validate, elementwise::prepare,
                      elementwise::executeExp);
NN_REGISTER_OPERATION(LOG, "LOG", elementwise::validate, elementwise::prepare,
                      elementwise::executeLog);
NN_REGISTER_OPERATION(RSQRT, "RSQRT", elementwise::validate, elementwise::prepare,
                      elementwise::executeRsqrt);
NN_REGISTER_OPERATION(SIN, "SIN", elementwise::validate, elementwise::prepare,
                      elementwise::executeSin);
NN_REGISTER_OPERATION(SQRT, "SQRT", elementwise::validate, elementwise::prepare,
                      elementwise::executeSqrt);
121
122 } // namespace nn
123 } // namespace android
124