1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPPLER_ITEM_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_GRAPPLER_ITEM_H_ 18 19 #include <memory> 20 #include <string> 21 #include <unordered_set> 22 #include <utility> 23 #include <vector> 24 25 #include "tensorflow/core/framework/graph.pb.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/variable.pb.h" 28 #include "tensorflow/core/protobuf/queue_runner.pb.h" 29 30 namespace tensorflow { 31 namespace grappler { 32 33 // A TensorFlow model to optimize. 34 // Models are represented by the combination of a graph, one of more fetch 35 // nodes, and potentially a set of nodes to feed. 36 struct GrapplerItem { 37 GrapplerItem() = default; 38 GrapplerItem(const GrapplerItem& other) = default; 39 GrapplerItem(GrapplerItem&& other) = default; 40 GrapplerItem& operator=(const GrapplerItem& other) = default; 41 GrapplerItem& operator=(GrapplerItem&& other) = default; 42 virtual ~GrapplerItem() = default; 43 44 // Create a copy of this GrapplerItem with graph swapped with the argument. 45 GrapplerItem WithGraph(GraphDef&& graph) const; 46 47 string id; // A unique id for this item 48 49 // Inputs 50 GraphDef graph; 51 std::vector<std::pair<string, Tensor>> feed; 52 std::vector<string> fetch; 53 54 // Initialization op(s). 55 std::vector<string> init_ops; 56 // Expected initialization time in seconds, or 0 if unknown 57 int64 expected_init_time = 0; 58 59 // Save/restore ops (if any) 60 string save_op; 61 string restore_op; 62 string save_restore_loc_tensor; 63 64 // Queue runner(s) required to run the queue(s) of this model. 65 std::vector<QueueRunnerDef> queue_runners; 66 67 // List of op names to keep in the graph. This includes nodes that are 68 // referenced in various collections, and therefore must be preserved to 69 // ensure that the optimized metagraph can still be loaded. 70 std::vector<string> keep_ops; 71 72 // Return the set of node evaluated during a regular train/inference step. 73 std::vector<const NodeDef*> MainOpsFanin() const; 74 // Return the set of node run to populate the queues (if any). 75 std::vector<const NodeDef*> EnqueueOpsFanin() const; 76 // Return the set nodes used by TensorFlow to initialize the graph. 77 std::vector<const NodeDef*> InitOpsFanin() const; 78 // Return the set of variables accessed during a regular train/inference step. 79 std::vector<const NodeDef*> MainVariables() const; 80 // Return a set of node names that must be preserved. This includes feed and 81 // fetch nodes, keep_ops, init_ops. 82 std::unordered_set<string> NodesToPreserve() const; 83 84 struct OptimizationOptions { 85 // Is it allowed to add nodes to the graph that do not have registered 86 // gradient function. 87 bool allow_non_differentiable_rewrites = true; 88 89 // Tensorflow function execution semantics is slightly different from the 90 // main Tensorflow graph, and we need to make sure that we do not change it 91 // by running Grappler optimizer passes. One main difference is that 92 // functions do not prune ops with side-effects and dataset-output ops (see 93 // PruneFunctionBody in common_runtime/function.cc). 94 bool allow_pruning_stateful_and_dataset_ops = true; 95 96 // If true Grappler will optimize the main graph, and also all functions in 97 // the graph function library (function can't be polymorphic, it can't have 98 // undefined type parameters in the function signature, or placeholder 99 // attributes in the function body). 100 bool optimize_function_library = true; 101 }; 102 103 const std::unordered_set<string>& devices() const; 104 // Adds a device to a set of available devices, only if it's a valid fully 105 // defined device name. Returns `Status::OK()` if successfully added a device, 106 // and an error otherwise. 107 Status AddDevice(const string& device); 108 // Adds all valid devices from the other Grappler item to the device set. 109 Status AddDevices(const GrapplerItem& other); 110 // Adds all valid devices from the nodes of the graph to the device set. 111 // Returns `Status::OK()` if all device annotations found in a graph are valid 112 // fully defined device names, and an error otherwise. 113 Status InferDevicesFromGraph(); 114 // Clears a set of available devices. 115 void ClearDevices(); 116 117 const OptimizationOptions& optimization_options() const; 118 OptimizationOptions& optimization_options(); 119 120 private: 121 // TODO(ezhulenev) Make GrapplerItem a class and hide all public data members. 122 // TODO(ezhulenev): Migrate all unordered collections to absl. 123 124 // A set of fully defined device names that can be used to place the nodes of 125 // the `graph`. 126 // Example of a fully defined name: "/job:work/replica:1/task:1/device:CPU:0" 127 std::unordered_set<string> devices_; 128 129 OptimizationOptions optimization_options_; 130 }; 131 132 // Return the transitive fanin of a set of terminal nodes. 133 std::vector<const NodeDef*> ComputeTransitiveFanin( 134 const GraphDef& graph, const std::vector<string>& terminal_nodes); 135 136 // Return the transitive fanin of a set of terminal nodes. Sets 'ill_formed' to 137 // true if one of the node is missing in the graph, or some node inputs don't 138 // exist. 139 std::vector<const NodeDef*> ComputeTransitiveFanin( 140 const GraphDef& graph, const std::vector<string>& terminal_nodes, 141 bool* ill_formed); 142 143 } // end namespace grappler 144 } // end namespace tensorflow 145 146 #endif // TENSORFLOW_CORE_GRAPPLER_GRAPPLER_ITEM_H_ 147