• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_TRANSFORMER_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_TRANSFORMER_H_
18 
19 #include <deque>
20 #include <string>
21 #include <unordered_set>
22 #include <vector>
23 
24 #include "tensorflow/lite/delegates/gpu/common/model.h"
25 
26 namespace tflite {
27 namespace gpu {
28 
29 class TransformationReporter;
30 
31 struct TransformationContext {
32   GraphFloat32* graph;
33   TransformationReporter* reporter;
34 };
35 
// Outcome of a single attempt to apply a transformation.
enum class TransformStatus {
  // The transformation did not apply because a trivial precondition did not
  // match. Contrast with DECLINED, which carries an in-depth explanation of
  // why a transformation that could have applied was not applied.
  SKIPPED = 0,

  // The transformation was declined; the model was not modified.
  DECLINED = 1,

  // The transformation was applied successfully.
  APPLIED = 2,

  // The transformation may have been partially applied and left the model in
  // an invalid state. Treat this as unrecoverable.
  INVALID = 3,
};
54 
55 struct TransformResult {
56   TransformStatus status;
57   std::string message;
58 };
59 
60 // Class responsible for applying a transformation to a single node.
61 class NodeTransformation {
62  public:
63   virtual ~NodeTransformation() = default;
64 
65   virtual TransformResult ApplyToNode(Node* node, GraphFloat32* graph) = 0;
66 };
67 
68 // Class responsible for applying a transformation to a sequence of nodes.
69 // Nodes are guaranteed to depend on each other without extra dependents being
70 // spilled.
71 class SequenceTransformation {
72  public:
73   virtual ~SequenceTransformation() = default;
74 
75   // @return number of nodes in a sequence to apply this transformation.
76   virtual int ExpectedSequenceLength() const = 0;
77 
78   // Applies transformations to a sequence of nodes. Transformation
79   // implementation is free manipulate with sequence nodes including adding
80   // and/or deleting nodes. if there were updates to nodes in the end and/or
81   // beginning of the sequence, then referential consistency should be
82   // maintained by updating relevant references in nodes that precede this
83   // sequence or depend on a last node of the sequence.
84   virtual TransformResult ApplyToNodesSequence(
85       const std::vector<Node*>& sequence, GraphFloat32* graph) = 0;
86 };
87 
// Interface that accumulates the decisions and updates made by
// transformations. Each callback receives the transformation name, a string
// describing the nodes involved (presumably their ids), and a message.
class TransformationReporter {
 public:
  virtual ~TransformationReporter() = default;

  // Intended to be invoked when `transformation` declined to modify the nodes
  // identified by `node_ids`.
  virtual void DeclinedTransformation(const std::string& transformation,
                                      const std::string& node_ids,
                                      const std::string& message) = 0;

  // Intended to be invoked when `transformation` was applied to the nodes
  // identified by `node_ids`.
  virtual void AppliedTransformation(const std::string& transformation,
                                     const std::string& node_ids,
                                     const std::string& message) = 0;
};
101 
102 // A class is designed to perform model transformations.
103 class ModelTransformer {
104  public:
ModelTransformer(GraphFloat32 * graph,TransformationReporter * reporter)105   ModelTransformer(GraphFloat32* graph, TransformationReporter* reporter)
106       : graph_(graph), reporter_(reporter) {}
107 
108   // @return false if a graph is in the broken states can not be used any more
109   bool Apply(const std::string& name, SequenceTransformation* transformation);
110 
111   // @return false if a graph is in the broken states can not be used any more
112   bool Apply(const std::string& name, NodeTransformation* transformation);
113 
114  private:
115   bool ApplyStartingWithNode(const std::string& name,
116                              SequenceTransformation* transformation,
117                              Node* begin);
118 
AddNodeToProcess(Node * node)119   void AddNodeToProcess(Node* node) {
120     if (node && processed_.insert(node->id).second) {
121       to_process_.push_back(node->id);
122     }
123   }
124 
125   GraphFloat32* graph_;
126   TransformationReporter* reporter_;
127 
128   std::deque<NodeId> to_process_;
129   std::unordered_set<NodeId> processed_;
130 };
131 
132 class NullTransformationReporter : public TransformationReporter {
133  public:
DeclinedTransformation(const std::string & transformation,const std::string & nodes_id,const std::string & message)134   void DeclinedTransformation(const std::string& transformation,
135                               const std::string& nodes_id,
136                               const std::string& message) override {}
137 
AppliedTransformation(const std::string & transformation,const std::string & nodes_id,const std::string & message)138   void AppliedTransformation(const std::string& transformation,
139                              const std::string& nodes_id,
140                              const std::string& message) override {}
141 };
142 
143 }  // namespace gpu
144 }  // namespace tflite
145 
146 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_TRANSFORMER_H_
147