1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H_ 16 #define TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H_ 17 18 #include <string> 19 #include <vector> 20 21 #include "tensorflow/lite/toco/model.h" 22 #include "tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h" 23 #include "tensorflow/lite/toco/tooling_util.h" 24 #include "tensorflow/core/framework/attr_value.pb.h" 25 #include "tensorflow/core/framework/graph.pb.h" 26 #include "tensorflow/core/framework/node_def.pb.h" 27 #include "tensorflow/core/framework/tensor.pb.h" 28 #include "tensorflow/core/framework/tensor_shape.pb.h" 29 30 namespace toco { 31 32 // The base class for Cluster. A cluster is group of nodes all related to each 33 // other because their name match a given "pattern", which shows they all belong 34 // to a composite op supported in TFLite. The nodes in a cluster will be 35 // collapsed into a single composite op node plus a series of constant nodes 36 // holding the input parameters to that node. The nodes in a cluster are assumed 37 // to be using the same device. By changing the "pattern" we can have different 38 // subclasses of the base Cluster class. 39 class Cluster { 40 public: ~Cluster()41 virtual ~Cluster() {} 42 43 virtual void CreateNodes() = 0; 44 45 // Save the following info from the original GraphDef this cluster is from: 46 // 1- a pointer to the GraphDef 47 // 2- All the nodes in GraphDef which belong to this cluster. 48 void SetGraphDefInfo(const tensorflow::GraphDef* graph_def); 49 GetName()50 const string& GetName() const { return name_; } 51 GetNewNodes()52 const std::vector<std::unique_ptr<tensorflow::NodeDef>>& GetNewNodes() const { 53 return new_nodes_; 54 } 55 GetNodes()56 const std::vector<const tensorflow::NodeDef*>& GetNodes() { return nodes_; } 57 SetName(const string & name)58 void SetName(const string& name) { name_ = name; } 59 SetDevice(const string & device)60 void SetDevice(const string& device) { device_ = device; } 61 62 // Find the input(s) and output(s) of this Cluster. 63 bool FindClusterInputsAndOutputs(); 64 65 protected: 66 string name_; 67 string device_; 68 std::vector<string> inputs_; 69 std::vector<string> outputs_; 70 71 // Used to hold the pointers to nodes which are in this cluster. These nodes 72 // are pointing to the nodes in graph_def_. 73 std::vector<const tensorflow::NodeDef*> nodes_; 74 75 // Used to cache the newly generated nodes: like the nodes created by 76 // collapsing Const nodes, or the nodes which is used to show the composite 77 // op. 78 std::vector<std::unique_ptr<tensorflow::NodeDef>> new_nodes_; 79 80 const tensorflow::GraphDef* graph_def_; /*Not owned*/ 81 }; 82 83 // A factory interface for cluster class. 84 // It defines a virtual function interface which is responsible for creating 85 // a cluster. Each cluster factory is responsible to pack a cluster of nodes 86 // into a cluster using a name-based pattern matching approach. 87 class ClusterFactoryInterface { 88 public: ~ClusterFactoryInterface()89 virtual ~ClusterFactoryInterface() {} 90 91 // Creates a cluster of nodes using a name-based pattern matching approach. It 92 // uses a node as a seed and if its name matches a certain pattern, then it 93 // builds the cluster around that node. 94 virtual std::unique_ptr<Cluster> CreateCluster( 95 const tensorflow::NodeDef& node, 96 const tensorflow::GraphDef& graph_def) const = 0; 97 }; 98 99 } // end namespace toco 100 101 #endif // TENSORFLOW_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H_ 102