1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_TRANSFORMER_H_ 17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_TRANSFORMER_H_ 18 19 #include <deque> 20 #include <string> 21 #include <utility> 22 #include <vector> 23 24 #include "absl/container/flat_hash_set.h" 25 #include "tensorflow/lite/delegates/gpu/common/model.h" 26 27 namespace tflite { 28 namespace gpu { 29 30 class TransformationReporter; 31 32 struct TransformationContext { 33 GraphFloat32* graph; 34 TransformationReporter* reporter; 35 }; 36 37 enum class TransformStatus { 38 // Transformation was not applied due to trivial conditions mismatch. 39 // 40 // This is different from DECLINED code below that provides in-depth 41 // explanation why a transformation that could have been applied but was not 42 // due to some issues. 43 SKIPPED, 44 45 // Transformation was declined, therefore, a model was not modified. 46 DECLINED, 47 48 // Transformation was applied successfully 49 APPLIED, 50 51 // Transformation may partially be applied, but left a model in an invalid 52 // state. This error should be considered unrecoverable. 53 INVALID, 54 }; 55 56 struct TransformResult { 57 TransformStatus status; 58 std::string message; 59 }; 60 61 // Class responsible for applying a transformation to a single node. 62 class NodeTransformation { 63 public: 64 virtual ~NodeTransformation() = default; 65 66 virtual TransformResult ApplyToNode(Node* node, GraphFloat32* graph) = 0; 67 }; 68 69 // Class responsible for applying a transformation to a sequence of nodes. 70 // Nodes are guaranteed to depend on each other without extra dependents being 71 // spilled. 72 class SequenceTransformation { 73 public: 74 virtual ~SequenceTransformation() = default; 75 76 // @return number of nodes in a sequence to apply this transformation. 77 virtual int ExpectedSequenceLength() const = 0; 78 79 // Applies transformations to a sequence of nodes. Transformation 80 // implementation is free manipulate with sequence nodes including adding 81 // and/or deleting nodes. if there were updates to nodes in the end and/or 82 // beginning of the sequence, then referential consistency should be 83 // maintained by updating relevant references in nodes that precede this 84 // sequence or depend on a last node of the sequence. 85 virtual TransformResult ApplyToNodesSequence( 86 const std::vector<Node*>& sequence, GraphFloat32* graph) = 0; 87 }; 88 89 // A class accumulated decisions or updates done by transformations. 90 class TransformationReporter { 91 public: 92 virtual ~TransformationReporter() = default; 93 94 virtual void DeclinedTransformation(const std::string& transformation, 95 const std::string& node_ids, 96 const std::string& message) = 0; 97 98 virtual void AppliedTransformation(const std::string& transformation, 99 const std::string& node_ids, 100 const std::string& message) = 0; 101 }; 102 103 // A class is designed to perform model transformations. 104 class ModelTransformer { 105 public: ModelTransformer(GraphFloat32 * graph,TransformationReporter * reporter)106 ModelTransformer(GraphFloat32* graph, TransformationReporter* reporter) 107 : graph_(graph), reporter_(reporter) {} 108 109 // @return false if a graph is in the broken states can not be used any more 110 bool Apply(const std::string& name, SequenceTransformation* transformation); 111 112 // @return false if a graph is in the broken states can not be used any more 113 bool Apply(const std::string& name, NodeTransformation* transformation); 114 115 private: 116 bool ApplyStartingWithNode(const std::string& name, 117 SequenceTransformation* transformation, 118 Node* begin); 119 AddNodeToProcess(Node * node)120 void AddNodeToProcess(Node* node) { 121 if (node && processed_.insert(node->id).second) { 122 to_process_.push_back(node->id); 123 } 124 } 125 126 GraphFloat32* graph_; 127 TransformationReporter* reporter_; 128 129 std::deque<NodeId> to_process_; 130 absl::flat_hash_set<NodeId> processed_; 131 }; 132 133 class NullTransformationReporter : public TransformationReporter { 134 public: DeclinedTransformation(const std::string & transformation,const std::string & nodes_id,const std::string & message)135 void DeclinedTransformation(const std::string& transformation, 136 const std::string& nodes_id, 137 const std::string& message) override {} 138 AppliedTransformation(const std::string & transformation,const std::string & nodes_id,const std::string & message)139 void AppliedTransformation(const std::string& transformation, 140 const std::string& nodes_id, 141 const std::string& message) override {} 142 }; 143 144 } // namespace gpu 145 } // namespace tflite 146 147 #endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_TRANSFORMER_H_ 148