/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Contains the implementation of the operations. #define LOG_TAG "Operations" #include #include #include #include "CpuOperationUtils.h" #include "Operations.h" #include "Tracing.h" namespace android { namespace nn { bool meanFloat16(_Float16* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape, bool keepDims, _Float16* outputData, const Shape& outputShape) { NNTRACE_TRANS("meanFloat16"); std::vector inputDataFloat32(getNumberOfElements(inputShape)); convertFloat16ToFloat32(inputData, &inputDataFloat32); std::vector outputDataFloat32(getNumberOfElements(outputShape)); meanGeneric(inputDataFloat32.data(), inputShape, axis, axisShape, keepDims, outputDataFloat32.data(), outputShape); convertFloat32ToFloat16(outputDataFloat32, outputData); return true; } template bool meanGeneric(T* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape, bool keepDims, T* outputData, const Shape& outputShape) { NNTRACE_TRANS("meanGeneric"); // Creates a temp index to iterate through input data. int32_t* scratchBuffer = new int32_t[getNumberOfDimensions(inputShape)]; // Creates a temp tensor to store resolved axis given input data. int32_t axisSize = static_cast(getSizeOfDimension(axisShape, 0)); int32_t* resolvedAxis = new int32_t[axisSize]; bool result = true; U* tempSumBuffer = new (std::nothrow) U[getNumberOfElements(outputShape)]; if (!tempSumBuffer) { LOG(ERROR) << "Failed to allocate tempSumBuffer for MEAN"; result = false; } else { NNTRACE_COMP_SWITCH("optimized_ops::Mean"); tflite::reference_ops::Mean( inputData, reinterpret_cast(inputShape.dimensions.data()), getNumberOfDimensions(inputShape), outputData, reinterpret_cast(outputShape.dimensions.data()), getNumberOfDimensions(outputShape), axis, axisSize, keepDims, scratchBuffer, resolvedAxis, tempSumBuffer); delete[] tempSumBuffer; } delete[] scratchBuffer; delete[] resolvedAxis; return result; } template bool meanGeneric(float* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape, bool keepDims, float* outputData, const Shape& outputShape); template bool meanGeneric(uint8_t* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape, bool keepDims, uint8_t* outputData, const Shape& outputShape); template bool meanGeneric(int8_t* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape, bool keepDims, int8_t* outputData, const Shape& outputShape); } // namespace nn } // namespace android