1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_ 18 19 // Operations calling functions are becoming ubiquitous in TF 2.0. 20 // Examples include PartitionedCallOp, functional If/While, and Dataset ops. 21 // Such operations might require deep inspection - looking at the body of the 22 // called function - to place them and surrounding ops correctly. 23 24 // This file contains some utilities for placer to correctly place such ops 25 // including: 26 // - PlacerInspectionRequiredOpChecker: A simple class with a single 27 // IsPlacerInspectionRequired method. 28 // - IsolatePlacerInspectionRequiredOps: This function adds Identity ops for 29 // each input/output of ops requiring placer inspection. It greatly simplifies 30 // the implementation of placing such ops. 31 32 #include <vector> 33 34 #include "absl/types/optional.h" 35 #include "tensorflow/core/framework/function.h" 36 #include "tensorflow/core/graph/graph.h" 37 #include "tensorflow/core/lib/core/status.h" 38 39 namespace tensorflow { 40 41 // PlacerInspectionRequiredOpChecker allows one to check if Placer needs to 42 // look deeply into the op to place ops consuming the outputs correctly. 43 // 44 // It is a class instead of a standalone method because checking whether 45 // a function returns a resource takes non-trivial time and we cache the 46 // results. 47 class PlacerInspectionRequiredOpChecker { 48 public: 49 // Calls the constructor below with flib_def = graph->flib_def(). 50 explicit PlacerInspectionRequiredOpChecker(const Graph* graph); 51 // Constructs a PlacerInspectionRequiredOpChecker for nodes of `graph`. 52 // The functions referenced by nodes in `graph` will be looked up in 53 // `flib_def` 54 PlacerInspectionRequiredOpChecker(const Graph* graph, 55 const FunctionLibraryDefinition* flib_def); 56 57 // If `node` is considered a deep op, sets `*is_deep` to true and returns 58 // Status::OK(). If an error occurs, returns that error, and the value of 59 // `*is_deep` is undefined. 60 // Currently, an op is considered deep, if it is a calling a function 61 // returning a resource. This definition is driven by Placer's need to 62 // look inside the op. 63 // REQUIRES: `node` is part of `graph` passed into constructor. 64 Status IsPlacerInspectionRequired(const Node& node, bool* is_deep); 65 66 private: 67 const Graph& graph_; 68 const FunctionLibraryDefinition& flib_def_; 69 // Indexed by the node id. 70 // If cache_[node_id] is empty, the deepness of the node with id `node_id` has 71 // not been computed yet. Else, it contains the value already computed. 72 std::vector<absl::optional<bool>> cache_; 73 }; 74 75 // Extracts `fdef` and `func` from `flib_def` for the function identified 76 // in "f" attribute of `node`. 77 Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def, 78 const Node& node, const FunctionDef** fdef, 79 NameAttrList* func); 80 81 // The "call" stack of functions. 82 // Useful for better error messages as well as for detecting recursion. 83 // Stores references to graph nodes. These references must outlive this. 84 class FunctionStack { 85 public: 86 explicit FunctionStack(const string& function_name); 87 88 // `node_in_current_function` must outlive this. 89 FunctionStack Push(const Node* node_in_current_function, 90 const string& new_current_function) const; 91 92 // Returns true iff this stack already includes `function_name`. 93 bool HasFunction(const string& function_name) const; 94 current_function_name()95 const string& current_function_name() const { return current_function_name_; } 96 97 // Format's this suitable for error interpolation that retrieves 98 // Python files and line numbers. 99 string FormatForError() const; 100 101 private: 102 struct Frame { FrameFrame103 Frame(const string& function, const Node* node) 104 : function_name(function), node(node) {} 105 106 string function_name; 107 const Node* node; 108 }; 109 110 // The function at the top of the stack. In other words, the function 111 // that is currently being inspected for placement. 112 string current_function_name_; 113 114 // The stack of frames that got the placement to the current_function_name_. 115 // frames_[0].function_name is the top function that Placer was constructed 116 // with. frames_[0].function_name can be empty if placer was constructed with 117 // a nameless graph, not a function. frames_[0].node_name is a name of a node 118 // in frames_[0].function_name that required deep inspection (e.g. a 119 // PartitionedCallOp). The function that this node invoked is 120 // frames_[1].function_name, if frames_.size() > 1. Else, the function that 121 // this node invoked is current_function_name_. 122 std::vector<Frame> frames_; 123 }; 124 125 // Adds Identities for each input and output of function-calling ops in `graph` 126 // 127 // For example, the following graph calling a function on inputs `a` and `b` 128 // and producing output `y` will be rewritten to include identities on all 129 // edges: 130 // 131 // a b 132 // | | 133 // v v 134 // f (PartitionedCallOp) 135 // | 136 // v 137 // y 138 // 139 // is transformed to 140 // 141 // a b 142 // | | 143 // a_f (Identity) b_f (Identity) 144 // | | 145 // v v 146 // f (PartitionedCallOp) 147 // | 148 // f_y (Identity) 149 // | 150 // v 151 // y 152 // 153 Status IsolatePlacerInspectionRequiredOps( 154 const FunctionLibraryDefinition& flib_def, Graph* graph); 155 156 } // namespace tensorflow 157 158 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_ 159