• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_ML_NN_COMMON_CPU_OPERATION_UTILS_H
18 #define ANDROID_ML_NN_COMMON_CPU_OPERATION_UTILS_H
19 
20 #include "OperationsUtils.h"
21 
22 #include "tensorflow/contrib/lite/kernels/internal/types.h"
23 
24 namespace android {
25 namespace nn {
26 
27 // The implementations in tflite/kernels/internal/ take a Dims<4> object
28 // even if the original tensors were not 4D.
convertShapeToDims(const Shape & shape)29 inline tflite::Dims<4> convertShapeToDims(const Shape& shape) {
30   nnAssert(shape.dimensions.size() <= 4);
31   tflite::Dims<4> dims;
32 
33   // The dimensions are reversed in Dims<4>.
34   for (int i = 0; i < 4; ++i) {
35     int src = static_cast<int>(shape.dimensions.size()) - i - 1;
36     if (src >= 0) {
37       dims.sizes[i] = static_cast<int>(getSizeOfDimension(shape, src));
38     } else {
39       dims.sizes[i] = 1;
40     }
41   }
42 
43   dims.strides[0] = 1;
44   for (int i = 1; i<4; i++) {
45     dims.strides[i] = dims.strides[i-1] * dims.sizes[i-1];
46   }
47   return dims;
48 }
49 
50 } // nn
51 } // android
52 
53 #endif // ANDROID_ML_NN_COMMON_CPU_OPERATION_UTILS_H
54