• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Contains the implementation of the operations.
18 
19 #define LOG_TAG "Operations"
20 
21 #include "Operations.h"
22 #include "OperationsUtils.h"
23 
24 #include "internal/optimized/optimized_ops.h"
25 
26 namespace android {
27 namespace nn {
28 
reshapeGeneric(const void * inputData,const Shape & inputShape,void * outputData,const Shape & outputShape)29 bool reshapeGeneric(const void* inputData, const Shape& inputShape,
30                     void* outputData, const Shape& outputShape) {
31     size_t count = sizeOfData(inputShape.type, inputShape.dimensions);
32     memcpy(outputData, inputData, count);
33     return true;
34 }
35 
resizeBilinearFloat32(const float * inputData,const Shape & inputShape,float * outputData,const Shape & outputShape)36 bool resizeBilinearFloat32(const float* inputData, const Shape& inputShape,
37                            float* outputData, const Shape& outputShape) {
38     int32_t height = (int32_t) getSizeOfDimension(outputShape, 1);
39     int32_t width  = (int32_t) getSizeOfDimension(outputShape, 2);
40 
41     int32_t outDimData[2] = {height, width};
42     // We have to fake a tensor here, to satisfy ResizeBilinear().
43     Shape outDimShape;
44     outDimShape.dimensions = {1, 1, 1, 2};
45 
46     optimized_ops::ResizeBilinear(
47             inputData, convertShapeToDims(inputShape),
48             outDimData, convertShapeToDims(outDimShape),
49             outputData, convertShapeToDims(outputShape));
50     return true;
51 }
52 
depthToSpaceGeneric(const uint8_t * inputData,const Shape & inputShape,int32_t blockSize,uint8_t * outputData,const Shape & outputShape)53 bool depthToSpaceGeneric(const uint8_t* inputData, const Shape& inputShape,
54                          int32_t blockSize,
55                          uint8_t* outputData, const Shape& outputShape) {
56     if (inputShape.type == OperandType::TENSOR_FLOAT32) {
57         optimized_ops::DepthToSpace(
58                 reinterpret_cast<const float*>(inputData),
59                 convertShapeToDims(inputShape),
60                 blockSize,
61                 reinterpret_cast<float*>(outputData),
62                 convertShapeToDims(outputShape));
63     } else if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
64         optimized_ops::DepthToSpace(
65                 reinterpret_cast<const uint8_t*>(inputData),
66                 convertShapeToDims(inputShape),
67                 blockSize,
68                 reinterpret_cast<uint8_t*>(outputData),
69                 convertShapeToDims(outputShape));
70     } else {
71         LOG(ERROR) << "Unsupported data type";
72         return false;
73     }
74     return true;
75 }
76 
spaceToDepthGeneric(const uint8_t * inputData,const Shape & inputShape,int32_t blockSize,uint8_t * outputData,const Shape & outputShape)77 bool spaceToDepthGeneric(const uint8_t* inputData, const Shape& inputShape,
78                          int32_t blockSize,
79                          uint8_t* outputData, const Shape& outputShape) {
80     if (inputShape.type == OperandType::TENSOR_FLOAT32) {
81         optimized_ops::SpaceToDepth(
82                 reinterpret_cast<const float*>(inputData),
83                 convertShapeToDims(inputShape),
84                 blockSize,
85                 reinterpret_cast<float*>(outputData),
86                 convertShapeToDims(outputShape));
87     } else if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
88         optimized_ops::SpaceToDepth(
89                 reinterpret_cast<const uint8_t*>(inputData),
90                 convertShapeToDims(inputShape),
91                 blockSize,
92                 reinterpret_cast<uint8_t*>(outputData),
93                 convertShapeToDims(outputShape));
94     } else {
95         LOG(ERROR) << "Unsupported data type";
96         return false;
97     }
98     return true;
99 }
100 
101 } // namespace nn
102 } // namespace android
103