1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_ 17 #define TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_ 18 19 #include <string> 20 21 #include "tensorflow/core/framework/device_attributes.pb.h" 22 #include "tensorflow/core/graph/graph.h" 23 #include "tensorflow/core/graph/node_builder.h" 24 #include "tensorflow/core/lib/core/status.h" 25 #include "tensorflow/core/lib/gtl/array_slice.h" 26 #include "tensorflow/core/protobuf/config.pb.h" 27 28 namespace tensorflow { 29 namespace subgraph { 30 31 // Information about a graph rewritten by `RewriteGraphForExecution()`. 32 struct RewriteGraphMetadata { 33 // The element type of each tensor fed to this subgraph. The order 34 // of types corresponds to the order of tensor names in 35 // `fed_outputs` when calling `RewriteGraphForExecution()`. 36 DataTypeVector feed_types; 37 // The element type of each tensor fetched from this subgraph. The 38 // order of types corresponds to the order of tensor names in 39 // `fetch_outputs` when calling `RewriteGraphForExecution()`. 40 DataTypeVector fetch_types; 41 }; 42 43 // Describes the action to take on a particular tensor endpoint (described by 44 // a "<node_name>:<output_index>" pair) when pruning the graph. 45 // 46 // The `AddNode()` method must be overridden to describe this action. The method 47 // will be invoked once during `RewriteGraphForExecution()` with tensor endpoint 48 // named by `endpoint_name`, and it may either create a single new node, or fail 49 // with an error if the resulting graph would be invalid. 50 class PruneRewrite { 51 public: 52 // `endpoint_name` and `device_info` must outlive this object. PruneRewrite(const string * endpoint_name,const DeviceAttributes * device_info)53 PruneRewrite(const string* endpoint_name, const DeviceAttributes* device_info) 54 : endpoint_name_(endpoint_name), device_info_(device_info) {} ~PruneRewrite()55 virtual ~PruneRewrite() {} 56 57 // Creates a new node whose output replaces the given `tensor` in graph `g`. 58 // The node will be assigned to the device named in `device_info`. 59 virtual Status AddNode(Graph* g, NodeBuilder::NodeOut tensor, 60 Node** out_node) = 0; 61 62 // Returns the name of the tensor to which this rewrite applies. endpoint_name()63 const string& endpoint_name() { return *endpoint_name_; } 64 65 protected: 66 // The device on which the new node will be created. device_info()67 const DeviceAttributes& device_info() { return *device_info_; } 68 69 private: 70 const string* const endpoint_name_; // Not owned. 71 const DeviceAttributes* const device_info_; // Not owned. 72 }; 73 74 // Rewrite the graph structure of "*g" to deal with feeding node 75 // outputs, fetching node outputs, and only running a subset of the 76 // graph. "fed_outputs" and "fetch_outputs" are both lists of 77 // output tensor identifiers in the form of 78 // "<name>[:<optional_output_index>]", and "target_nodes_str" is a 79 // lists of target node names in "*g" "g". 80 // 81 // In the resulting graph "*g", output edges in "fed_outputs" have 82 // been redirected to special "_recv" nodes introduced into the graph. 83 // If these fed nodes are not needed in order to compute the effects 84 // of the nodes in "target_node_names" and "fetch_outputs", then these may 85 // be omitted from the graph. 86 // 87 // In the resulting graph "*g", additional "_send" nodes are connected 88 // to every output in "fetch_outputs". These "_send" nodes are set up 89 // to execute on the device described by device_info. 90 // 91 // On success, returns OK, and sets "*g" to a version of "*g" 92 // that represents the portions of the graph necessary for producing 93 // the output of all nodes listed in "target_node_names" and fetching the 94 // specific node outputs specified in "fetch_outputs". 95 // 96 // On failure, returns the error status. Possible errors include: 97 // - fed output "node:output_index" does not exist in "*g" 98 // - fetch output "node:output_index" does not exist in "*g" 99 // - target node "node" does not exist in "*g" 100 Status RewriteGraphForExecution( 101 Graph* g, const gtl::ArraySlice<string>& fed_outputs, 102 const gtl::ArraySlice<string>& fetch_outputs, 103 const gtl::ArraySlice<string>& target_node_names, 104 const DeviceAttributes& device_info, bool use_function_convention, 105 RewriteGraphMetadata* out_metadata); 106 107 // A more general version of the above function that supports 108 // customizable rewriting actions for each fed and fetched tensor. 109 Status RewriteGraphForExecution( 110 Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites, 111 const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites, 112 const gtl::ArraySlice<string>& target_node_names, 113 RewriteGraphMetadata* out_metadata); 114 115 ///////////////////////////////////////////////////////// 116 // Custom rewrite actions for fed and fetched tensors. // 117 ///////////////////////////////////////////////////////// 118 119 // A rewrite action that adds an _Arg node for a fed tensor. 120 class ArgFeedRewrite : public PruneRewrite { 121 public: ArgFeedRewrite(const string * endpoint_name,const DeviceAttributes * device_info,int32 arg_index)122 ArgFeedRewrite(const string* endpoint_name, 123 const DeviceAttributes* device_info, int32 arg_index) 124 : PruneRewrite(endpoint_name, device_info), arg_index_(arg_index) {} 125 Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, 126 Node** out_node) override; 127 128 private: 129 const int32 arg_index_; 130 }; 131 132 // A rewrite action that adds a client-terminated _Recv node for a fed tensor. 133 class RecvFeedRewrite : public PruneRewrite { 134 public: 135 using PruneRewrite::PruneRewrite; 136 Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, 137 Node** out_node) override; 138 }; 139 140 // A rewrite action that adds a _Retval node for a fetched tensor. 141 class RetvalFetchRewrite : public PruneRewrite { 142 public: RetvalFetchRewrite(const string * endpoint_name,const DeviceAttributes * device_info,int32 retval_index)143 RetvalFetchRewrite(const string* endpoint_name, 144 const DeviceAttributes* device_info, int32 retval_index) 145 : PruneRewrite(endpoint_name, device_info), retval_index_(retval_index) {} 146 Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, 147 Node** out_node) override; 148 149 private: 150 const int32 retval_index_; 151 }; 152 153 // A rewrite action that adds a client-terminated _Send node for a 154 // fetched tensor. 155 class SendFetchRewrite : public PruneRewrite { 156 public: 157 using PruneRewrite::PruneRewrite; 158 Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, 159 Node** out_node) override; 160 }; 161 162 } // namespace subgraph 163 } // namespace tensorflow 164 165 #endif // TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_ 166