/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_TYPES_H_
18 #define ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_TYPES_H_
19
20 #include "compatibility.h"
21
22 namespace android {
23 namespace nn {
24
// Selects the activation function applied to an operation's output.
// NOTE(review): the meanings below are inferred from the enumerator names
// only (kRelu: max(0, x), kRelu1: clamp to [-1, 1], kRelu6: clamp to [0, 6]);
// confirm against the kernels that consume this enum.
enum class FusedActivationFunctionType {
  kNone = 0,
  kRelu6 = 1,
  kRelu1 = 2,
  kRelu = 3,
};
26
// Layout descriptor for an N-dimensional array.
//   sizes[i]   - number of elements along dimension i.
//   strides[i] - distance, in elements, between neighbouring indices along
//                dimension i.
// The member order (all sizes, then all strides) is part of the contract:
// callers may brace-initialize this aggregate positionally.
template <int N>
struct Dims {
  int sizes[N];    // extent of each dimension
  int strides[N];  // element step of each dimension
};
32
33 struct Shape;
34
convertShapeToDims(const Shape & shape)35 inline Dims<4> convertShapeToDims(const Shape& shape) {
36 Dims<4> dims;
37 for (int i=0; i<4; i++) {
38 dims.sizes[i] = 1;
39 }
40
41 if (shape.dimensions.size() == 1) {
42 dims.sizes[0] = (int)getSizeOfDimension(shape, 0);
43 } else {
44 for (int i=0; i<4; i++) {
45 int src = (int)shape.dimensions.size()-i-1;
46 if (src >= 0) {
47 dims.sizes[i] = (int)getSizeOfDimension(shape, src);
48 }
49 }
50 }
51
52 dims.strides[0] = 1;
53 for (int i = 1; i<4; i++) {
54 dims.strides[i] = dims.strides[i-1] * dims.sizes[i-1];
55 }
56 return dims;
57 }
58
Offset(const Dims<4> & dims,int i0,int i1,int i2,int i3)59 inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
60 DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
61 DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
62 DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
63 DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
64 return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
65 i3 * dims.strides[3];
66 }
67
68 // Get array size, DCHECKing that the dim index is in range.
69 template <int N>
ArraySize(const Dims<N> & array,int index)70 int ArraySize(const Dims<N>& array, int index) {
71 DCHECK(index >= 0 && index < N);
72 return array.sizes[index];
73 }
74
// Returns the extent of dimension `index1` of `array1`. In debug builds it is
// first DCHECKed to equal the extent of dimension `index2` of `array2`.
// This two-array overload is also where the variadic overload's recursion
// bottoms out.
template <typename ArrayType1, typename ArrayType2>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2) {
  DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  const int common_size = ArraySize(array1, index1);
  return common_size;
}
82
// Variadic form of MatchingArraySize: accepts any number of further
// (array, index) pairs after the first two. In debug builds it DCHECKs that
// every listed dimension has the same extent as dimension `index1` of
// `array1`, and it returns that extent. Each step compares array1 against
// array2 and then recurses on (array1, index1, remaining pairs...); the
// two-array overload terminates the recursion.
//
// The trailing pairs are taken by const reference so that Dims arguments are
// not copied at every recursion level (the previous by-value pack copied them).
template <typename ArrayType1, typename ArrayType2, typename... Args>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2,
                      const Args&... args) {
  DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return MatchingArraySize(array1, index1, args...);
}
89
RequiredBufferSizeForDims(const Dims<4> & dims)90 inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
91 int max_offset = 0;
92 for (int i = 0; i < 4; i++) {
93 max_offset += (dims.sizes[i] - 1) * dims.strides[i];
94 }
95 return max_offset + 1;
96 }
97
98 template <int N>
IsPackedWithoutStrides(const Dims<N> & dims)99 bool IsPackedWithoutStrides(const Dims<N>& dims) {
100 int expected_stride = 1;
101 for (int d = 0; d < N; d++) {
102 if (dims.strides[d] != expected_stride) return false;
103 expected_stride *= dims.sizes[d];
104 }
105 return true;
106 }
107
108 } // namespace nn
109 } // namespace android
110
111 #endif // ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_TYPES_H_
112