1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_LSTM_NORMALIZATION_H_ 17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_LSTM_NORMALIZATION_H_ 18 19 #include <map> 20 #include <set> 21 #include <string> 22 #include <vector> 23 24 #include "tensorflow/lite/delegates/gpu/common/model.h" 25 #include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h" 26 #include "tensorflow/lite/delegates/gpu/common/status.h" 27 #include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" 28 #include "tensorflow/lite/delegates/gpu/common/types.h" 29 30 namespace tflite { 31 namespace gpu { 32 33 // Implements tensor_utils::MeanStddevNormalization 34 class MeanStdDevNormalization : public GPUOperation { 35 public: 36 explicit MeanStdDevNormalization(const OperationDef& definition, 37 const GpuInfo& gpu_info, const BHWC& shape, 38 float variance_bias, bool two_step); 39 GetPossibleKernelWorkGroups(TuningType tuning_type,const GpuInfo & gpu_info,const KernelInfo & kernel_info,std::vector<int3> * work_groups)40 void GetPossibleKernelWorkGroups( 41 TuningType tuning_type, const GpuInfo& gpu_info, 42 const KernelInfo& kernel_info, 43 std::vector<int3>* work_groups) const override { 44 work_groups->push_back(work_group_size_); 45 } 46 int3 GetGridSize() const override; 47 48 // Move only 49 MeanStdDevNormalization(MeanStdDevNormalization&& kernel) = default; 50 MeanStdDevNormalization& operator=(MeanStdDevNormalization&& kernel) = 51 default; 52 MeanStdDevNormalization(const MeanStdDevNormalization&) = delete; 53 MeanStdDevNormalization& operator=(const MeanStdDevNormalization&) = delete; 54 55 private: 56 std::string GetNormalizationCode(const GpuInfo& gpu_info, bool channels_x4, 57 bool two_step); 58 }; 59 60 // std dev can be calculated in single step, but two step algorithm can 61 // provide more stable and robust results 62 MeanStdDevNormalization CreateMeanStdDevNormalization( 63 const OperationDef& definition, const GpuInfo& gpu_info, const BHWC& shape, 64 float variance_bias = 1.0e-8f, bool two_step = true); 65 66 // MeanStdDevNormalization fusion works with this subgraph 67 // input 68 // / \ 69 // | mean 70 // \ / 71 // substraction 72 // / \ 73 // | | 74 // | square 75 // | | 76 // | mean 77 // | | 78 // | add 79 // | | 80 // | rsqrt 81 // | | 82 // \ / 83 // multiplication 84 // | 85 // output 86 absl::Status TryMeanStdDevNormalization( 87 const GpuInfo& gpu_info, CalculationsPrecision precision, 88 const GraphFloat32& graph, NodeId first_node_id, 89 const std::map<ValueId, TensorDescriptor>& tensor_descriptors, 90 std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph); 91 92 } // namespace gpu 93 } // namespace tflite 94 95 #endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_LSTM_NORMALIZATION_H_ 96