1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FRAMEWORKS_ML_NN_SVDF_H 18 #define FRAMEWORKS_ML_NN_SVDF_H 19 20 #include "ActivationFunctor.h" 21 22 #include <algorithm> 23 #include <cmath> 24 25 namespace android { 26 namespace hardware { 27 namespace neuralnetworks { 28 namespace V1_0 { 29 struct Operation; 30 } 31 } // namespace neuralnetworks 32 } // namespace hardware 33 } // namespace android 34 35 namespace android { 36 namespace nn { 37 38 struct SVDFParams { 39 int rank_; 40 ActivationFn activation_; 41 }; 42 43 struct RunTimeOperandInfo; 44 struct Shape; 45 46 class SVDF { 47 public: 48 SVDF(const android::hardware::neuralnetworks::V1_0::Operation &operation, 49 std::vector<RunTimeOperandInfo>& operands); 50 51 static bool Prepare( 52 const hardware::neuralnetworks::V1_0::Operation &operation, 53 std::vector<RunTimeOperandInfo> &operands, Shape *stateShape, 54 Shape *outputShape); 55 bool Eval(); 56 57 static constexpr int kInputTensor = 0; 58 static constexpr int kWeightsFeatureTensor = 1; 59 static constexpr int kWeightsTimeTensor = 2; 60 static constexpr int kBiasTensor = 3; // Optional 61 static constexpr int kStateInTensor = 4; 62 static constexpr int kRankParam = 5; 63 static constexpr int kActivationParam = 6; 64 65 static constexpr int kStateOutTensor = 0; 66 static constexpr int kOutputTensor = 1; 67 68 private: 69 SVDFParams params_; 70 71 const RunTimeOperandInfo *input_; 72 const RunTimeOperandInfo *weights_feature_; 73 const RunTimeOperandInfo *weights_time_; 74 const RunTimeOperandInfo *bias_; 75 const RunTimeOperandInfo *state_in_; 76 77 RunTimeOperandInfo *state_out_; 78 RunTimeOperandInfo *output_; 79 }; 80 81 } // namespace nn 82 } // namespace android 83 84 #endif 85