• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Contains the implementation of the operations.
18 
19 #define LOG_TAG "Operations"
20 
#include "CpuOperationUtils.h"
#include "Operations.h"

#include <memory>
#include <new>
#include <vector>

#include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"

#include "Tracing.h"
28 
29 namespace android {
30 namespace nn {
31 
floorFloat16(const _Float16 * inputData,_Float16 * outputData,const Shape & shape)32 bool floorFloat16(const _Float16* inputData, _Float16* outputData, const Shape& shape) {
33     NNTRACE_TRANS("floorFloat16");
34     std::vector<float> inputDataFloat32(getNumberOfElements(shape));
35     convertFloat16ToFloat32(inputData, &inputDataFloat32);
36 
37     std::vector<float> outputDataFloat32(getNumberOfElements(shape));
38     floorFloat32(inputDataFloat32.data(), outputDataFloat32.data(), shape);
39     convertFloat32ToFloat16(outputDataFloat32, outputData);
40     return true;
41 }
42 
floorFloat32(const float * inputData,float * outputData,const Shape & shape)43 bool floorFloat32(const float* inputData, float* outputData, const Shape& shape) {
44     NNTRACE_TRANS("floorFloat32");
45     tflite::Dims<4> dim = convertShapeToDims(shape);
46     NNTRACE_COMP_SWITCH("optimized_ops::Floor");
47     tflite::optimized_ops::Floor(inputData, dim, outputData, dim);
48     return true;
49 }
50 
meanFloat16(_Float16 * inputData,const Shape & inputShape,const int32_t * axis,const Shape & axisShape,bool keepDims,_Float16 * outputData,const Shape & outputShape)51 bool meanFloat16(_Float16* inputData, const Shape& inputShape, const int32_t* axis,
52                  const Shape& axisShape, bool keepDims, _Float16* outputData,
53                  const Shape& outputShape) {
54     NNTRACE_TRANS("meanFloat16");
55     std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
56     convertFloat16ToFloat32(inputData, &inputDataFloat32);
57 
58     std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
59     meanGeneric<float, float>(inputDataFloat32.data(), inputShape, axis, axisShape, keepDims,
60                               outputDataFloat32.data(), outputShape);
61     convertFloat32ToFloat16(outputDataFloat32, outputData);
62     return true;
63 }
64 
65 template <typename T, typename U>
meanGeneric(T * inputData,const Shape & inputShape,const int32_t * axis,const Shape & axisShape,bool keepDims,T * outputData,const Shape & outputShape)66 bool meanGeneric(T* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape,
67                  bool keepDims, T* outputData, const Shape& outputShape) {
68     NNTRACE_TRANS("meanGeneric");
69     // Creates a temp index to iterate through input data.
70     int32_t* scratchBuffer = new int32_t[getNumberOfDimensions(inputShape)];
71 
72     // Creates a temp tensor to store resolved axis given input data.
73     int32_t axisSize = static_cast<int32_t>(getSizeOfDimension(axisShape, 0));
74     int32_t* resolvedAxis = new int32_t[axisSize];
75 
76     bool result = true;
77     U* tempSumBuffer = new (std::nothrow) U[getNumberOfElements(outputShape)];
78     if (!tempSumBuffer) {
79         LOG(ERROR) << "Failed to allocate tempSumBuffer for MEAN";
80         result = false;
81     } else {
82         NNTRACE_COMP_SWITCH("optimized_ops::Mean");
83         tflite::reference_ops::Mean<T, U>(
84                 inputData, reinterpret_cast<const int*>(inputShape.dimensions.data()),
85                 getNumberOfDimensions(inputShape), outputData,
86                 reinterpret_cast<const int*>(outputShape.dimensions.data()),
87                 getNumberOfDimensions(outputShape), axis, axisSize, keepDims, scratchBuffer,
88                 resolvedAxis, tempSumBuffer);
89         delete[] tempSumBuffer;
90     }
91     delete[] scratchBuffer;
92     delete[] resolvedAxis;
93     return result;
94 }
95 template bool meanGeneric<float, float>(float* inputData, const Shape& inputShape,
96                                         const int32_t* axis, const Shape& axisShape, bool keepDims,
97                                         float* outputData, const Shape& outputShape);
98 template bool meanGeneric<uint8_t, int32_t>(uint8_t* inputData, const Shape& inputShape,
99                                             const int32_t* axis, const Shape& axisShape,
100                                             bool keepDims, uint8_t* outputData,
101                                             const Shape& outputShape);
102 
103 }  // namespace nn
104 }  // namespace android
105