• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "Tile.h"
20 #include "Tracing.h"
21 
22 namespace android {
23 namespace nn {
24 namespace tile {
25 
26 namespace {
27 
// Writes `multiplier` back-to-back copies of the `in_size` elements found at
// `in_data`, starting at `out_data`. Only the first pass reads from `in_data`;
// every later pass reads from the block written by the pass before it.
// A `multiplier` of zero or less writes nothing.
template <typename T>
void CopyMultipleTimes(const T* in_data, int32_t in_size, int32_t multiplier, T* out_data) {
    const T* source = in_data;
    T* destination = out_data;
    for (int32_t copies_left = multiplier; copies_left > 0; --copies_left) {
        T* const block_end = std::copy(source, source + in_size, destination);
        // The block just written becomes the source of the next copy.
        source = destination;
        destination = block_end;
    }
}
37 
38 template <typename T, typename M>
TileOneDimension(const Shape & input_shape,const T * in_data,const M * multipliers,T * out_data,int dimension)39 std::pair<int, int> TileOneDimension(const Shape& input_shape, const T* in_data,
40                                      const M* multipliers, T* out_data, int dimension) {
41     const int dimension_size = input_shape.dimensions[dimension];
42     if (dimension == input_shape.dimensions.size() - 1) {
43         CopyMultipleTimes(in_data, dimension_size, multipliers[dimension], out_data);
44         return std::make_pair(dimension_size,
45                               dimension_size * static_cast<int>(multipliers[dimension]));
46     }
47     int total_stride_size = 0, total_tiled_stride_size = 0;
48     const T* copy_from_data = in_data;
49     T* copy_to_data = out_data;
50     for (int i = 0; i < dimension_size; ++i) {
51         int stride_size = 0, tiled_stride_size = 0;
52         std::tie(stride_size, tiled_stride_size) = TileOneDimension(
53                 input_shape, copy_from_data, multipliers, copy_to_data, dimension + 1);
54         copy_from_data += stride_size;
55         copy_to_data += tiled_stride_size;
56         total_stride_size += stride_size;
57         total_tiled_stride_size += tiled_stride_size;
58     }
59     CopyMultipleTimes(out_data, total_tiled_stride_size, multipliers[dimension] - 1,
60                       out_data + total_tiled_stride_size);
61     return std::make_pair(total_stride_size, total_tiled_stride_size * multipliers[dimension]);
62 }
63 
64 template <typename T>
tileImpl(const T * inputData,const Shape & inputShape,const int32_t * multiples,T * outputData,const Shape & outputShape)65 void tileImpl(const T* inputData, const Shape& inputShape, const int32_t* multiples, T* outputData,
66               const Shape& outputShape) {
67     TileOneDimension(inputShape, inputData, multiples, outputData, 0);
68 }
69 
70 }  // namespace
71 
prepare(const Shape & input,const int32_t * multiples,const Shape & multiplesShape,Shape * output)72 bool prepare(const Shape& input, const int32_t* multiples, const Shape& multiplesShape,
73              Shape* output) {
74     output->type = input.type;
75     output->offset = input.offset;
76     output->scale = input.scale;
77 
78     output->dimensions.assign(input.dimensions.begin(), input.dimensions.end());
79     for (size_t i = 0; i < output->dimensions.size(); ++i) {
80         output->dimensions[i] *= multiples[i];
81     }
82 
83     return true;
84 }
85 
eval(const uint8_t * inputData,const Shape & inputShape,const int32_t * multiples,uint8_t * outputData,const Shape & outputShape)86 bool eval(const uint8_t* inputData, const Shape& inputShape, const int32_t* multiples,
87           uint8_t* outputData, const Shape& outputShape) {
88     NNTRACE_TRANS("tile::eval");
89 #define ANDROID_NN_IMPL_TILE(operandType, dataType)                                   \
90     case operandType: {                                                               \
91         NNTRACE_COMP_SWITCH("tileImpl::" #dataType);                                  \
92         tileImpl(reinterpret_cast<const dataType*>(inputData), inputShape, multiples, \
93                  reinterpret_cast<dataType*>(outputData), outputShape);               \
94         return true;                                                                  \
95     }
96 
97     switch (inputShape.type) {
98         ANDROID_NN_IMPL_TILE(OperandType::TENSOR_FLOAT16, _Float16);
99         ANDROID_NN_IMPL_TILE(OperandType::TENSOR_FLOAT32, float);
100         ANDROID_NN_IMPL_TILE(OperandType::TENSOR_INT32, int32_t);
101         ANDROID_NN_IMPL_TILE(OperandType::TENSOR_QUANT8_ASYMM, uint8_t);
102         default:
103             LOG(ERROR) << "Unsupported data type";
104             return false;
105     }
106 #undef ANDROID_NN_IMPL_TILE
107 }
108 
109 }  // namespace tile
110 }  // namespace nn
111 }  // namespace android
112