• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_TYPES_H_
18 #define ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_TYPES_H_
19 
20 #include "compatibility.h"
21 
22 namespace android {
23 namespace nn {
24 
// Selects the activation fused onto an operation's output.
// NOTE(review): the clamping behaviour implied by the names (e.g. kRelu6)
// is implemented by the kernels, not in this header. Enumerator order is
// significant: the underlying values are kNone=0, kRelu6=1, kRelu1=2, kRelu=3.
enum class FusedActivationFunctionType {
  kNone,
  kRelu6,
  kRelu1,
  kRelu,
};
26 
// Fixed-rank array descriptor used by the kernels in this directory.
// Within this header axis 0 is the innermost (fastest-varying) axis:
// convertShapeToDims() gives it stride 1, and Offset() turns per-axis
// indices into an element offset as sum(index[i] * strides[i]).
template <int N>
struct Dims {
  int sizes[N];    // Extent of each axis.
  int strides[N];  // Element step between consecutive indices of each axis.
};
32 
33 struct Shape;
34 
convertShapeToDims(const Shape & shape)35 inline Dims<4> convertShapeToDims(const Shape& shape) {
36   Dims<4> dims;
37   for (int i=0; i<4; i++) {
38     dims.sizes[i] = 1;
39   }
40 
41   if (shape.dimensions.size() == 1) {
42     dims.sizes[0] = (int)getSizeOfDimension(shape, 0);
43   } else {
44     for (int i=0; i<4; i++) {
45       int src = (int)shape.dimensions.size()-i-1;
46       if (src >= 0) {
47         dims.sizes[i] = (int)getSizeOfDimension(shape, src);
48       }
49     }
50   }
51 
52   dims.strides[0] = 1;
53   for (int i = 1; i<4; i++) {
54     dims.strides[i] = dims.strides[i-1] * dims.sizes[i-1];
55   }
56   return dims;
57 }
58 
Offset(const Dims<4> & dims,int i0,int i1,int i2,int i3)59 inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
60   DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
61   DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
62   DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
63   DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
64   return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
65          i3 * dims.strides[3];
66 }
67 
68 // Get array size, DCHECKing that the dim index is in range.
69 template <int N>
ArraySize(const Dims<N> & array,int index)70 int ArraySize(const Dims<N>& array, int index) {
71   DCHECK(index >= 0 && index < N);
72   return array.sizes[index];
73 }
74 
// Returns ArraySize(array1, index1). In DCHECK-enabled builds it also checks
// that this equals ArraySize(array2, index2).
template <typename ArrayType1, typename ArrayType2>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2) {
  const int extent = ArraySize(array1, index1);
  DCHECK_EQ(extent, ArraySize(array2, index2));
  return extent;
}

// Variadic form: accepts any number of further (array, index) pairs. It
// checks that every one of them has the same extent as (array1, index1) and
// returns that extent.
template <typename ArrayType1, typename ArrayType2, typename... Args>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2, Args... args) {
  // Partial ordering picks the non-variadic overload above for this call.
  // It performs the (array1, array2) comparison.
  MatchingArraySize(array1, index1, array2, index2);
  // Then compare array1 against the remaining pairs, one per recursion step.
  return MatchingArraySize(array1, index1, args...);
}
89 
RequiredBufferSizeForDims(const Dims<4> & dims)90 inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
91   int max_offset = 0;
92   for (int i = 0; i < 4; i++) {
93     max_offset += (dims.sizes[i] - 1) * dims.strides[i];
94   }
95   return max_offset + 1;
96 }
97 
98 template <int N>
IsPackedWithoutStrides(const Dims<N> & dims)99 bool IsPackedWithoutStrides(const Dims<N>& dims) {
100   int expected_stride = 1;
101   for (int d = 0; d < N; d++) {
102     if (dims.strides[d] != expected_stride) return false;
103     expected_stride *= dims.sizes[d];
104   }
105   return true;
106 }
107 
108 }  // namespace nn
109 }  // namespace android
110 
111 #endif  // ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_TYPES_H_
112