1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ 17 18 #include <vector> 19 20 #include "tensorflow/core/common_runtime/graph_runner.h" 21 #include "tensorflow/core/framework/function.pb.h" 22 #include "tensorflow/core/framework/shape_inference.h" 23 #include "tensorflow/core/graph/graph.h" 24 #include "tensorflow/core/lib/core/status.h" 25 #include "tensorflow/core/platform/macros.h" 26 27 namespace tensorflow { 28 namespace grappler { 29 class GraphProperties; 30 } 31 32 // This class stores extra inference information in addition to 33 // InferenceContext, such as inference tree for user-defined functions and node 34 // input and output types. 35 class ExtendedInferenceContext { 36 public: ExtendedInferenceContext(std::unique_ptr<shape_inference::InferenceContext> ic,const Node * node)37 ExtendedInferenceContext( 38 std::unique_ptr<shape_inference::InferenceContext> ic, const Node* node) 39 : inference_context_(std::move(ic)) { 40 input_types_.reserve(node->num_inputs()); 41 for (int i = 0; i < node->num_inputs(); i++) { 42 input_types_.push_back(node->input_type(i)); 43 } 44 output_types_.reserve(node->num_outputs()); 45 for (int i = 0; i < node->num_outputs(); i++) { 46 output_types_.push_back(node->output_type(i)); 47 } 48 } 49 50 const std::unordered_map<string, std::unique_ptr<ExtendedInferenceContext>>& nested_inferences()51 nested_inferences() const { 52 return nested_inferences_; 53 } input_type(int64 idx)54 DataType input_type(int64 idx) const { return input_types_[idx]; } output_type(int64 idx)55 DataType output_type(int64 idx) const { return output_types_[idx]; } 56 get_context()57 shape_inference::InferenceContext* get_context() { 58 return inference_context_.get(); 59 } 60 61 // Sets nested inference info. 62 // For composite ops (user-defined functions) only. 63 // Inference for trivial ops must not call this setter. set_nested_inferences(std::unordered_map<string,std::unique_ptr<ExtendedInferenceContext>> inferences)64 void set_nested_inferences( 65 std::unordered_map<string, std::unique_ptr<ExtendedInferenceContext>> 66 inferences) { 67 nested_inferences_ = std::move(inferences); 68 } 69 70 private: 71 std::unique_ptr<shape_inference::InferenceContext> inference_context_; 72 std::vector<DataType> input_types_; 73 std::vector<DataType> output_types_; 74 75 // Nested inferences for composite ops (user-defined functions). 76 // Mapping key is nested node name. 77 // For trivial ops this map must be empty. 78 std::unordered_map<string, std::unique_ptr<ExtendedInferenceContext>> 79 nested_inferences_; 80 81 TF_DISALLOW_COPY_AND_ASSIGN(ExtendedInferenceContext); 82 }; 83 84 // ShapeRefiner performs shape inference for TensorFlow Graphs. It is 85 // responsible for instantiating InferenceContext objects for each 86 // Node in the Graph, and providing/storing the 'input_tensor' Tensors 87 // used by Shape Inference functions, when available at graph 88 // construction time. 89 class ShapeRefiner { 90 public: 91 ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops); 92 93 // Same as ShapeRefiner(versions.producer(), ops) 94 ShapeRefiner(const VersionDef& versions, const OpRegistryInterface* ops); 95 96 ~ShapeRefiner(); 97 98 // Performs validation of 'node' and runs 'node's shape function, 99 // storing its shape outputs. 100 // 101 // All inputs of 'node' must be added to ShapeRefiner prior to 102 // adding 'node'. 103 // 104 // Returns an error if: 105 // - the shape function for 'node' was not registered. 106 // - 'node' was added before its inputs. 107 // - The shape inference function returns an error. 108 Status AddNode(const Node* node); 109 110 // Sets 'node's 'output_port' output to have shape 'shape'. 111 // 112 // Returns an error if 'node' was not previously added to this 113 // object, if 'output_port' is invalid, or if 'shape' is 114 // not compatible with the existing shape of the output. 115 Status SetShape(const Node* node, int output_port, 116 shape_inference::ShapeHandle shape); 117 118 // Update the input shapes of node in case the shapes of the fan-ins of 'node' 119 // have themselves been modified (For example, in case of incremental shape 120 // refinement). If 'relax' is true, a new shape with the broadest set of 121 // information will be set as the new input (see InferenceContext::RelaxInput 122 // for full details and examples). Sets refined to true if any shapes have 123 // changed (in their string representations). Note that shapes may have been 124 // updated to newer versions (but with identical string representations) even 125 // if <*refined> is set to false. 126 Status UpdateNode(const Node* node, bool relax, bool* refined); 127 128 // Returns the InferenceContext for 'node', if present. GetContext(const Node * node)129 shape_inference::InferenceContext* GetContext(const Node* node) const { 130 auto it = node_to_context_.find(node); 131 if (it == node_to_context_.end()) { 132 return nullptr; 133 } 134 return it->second->get_context(); 135 } 136 137 // Returns the ExtendedInferenceContext for 'node', if present. GetExtendedContext(const Node * node)138 ExtendedInferenceContext* GetExtendedContext(const Node* node) const { 139 auto it = node_to_context_.find(node); 140 if (it == node_to_context_.end()) { 141 return nullptr; 142 } 143 return it->second.get(); 144 } 145 146 // Getters and setters for graph_def_version_. graph_def_version()147 int32 graph_def_version() const { return graph_def_version_; } set_graph_def_version(int32 version)148 void set_graph_def_version(int32 version) { graph_def_version_ = version; } 149 set_require_shape_inference_fns(bool require_shape_inference_fns)150 void set_require_shape_inference_fns(bool require_shape_inference_fns) { 151 require_shape_inference_fns_ = require_shape_inference_fns; 152 } set_disable_constant_propagation(bool disable)153 void set_disable_constant_propagation(bool disable) { 154 disable_constant_propagation_ = disable; 155 } 156 157 // Set function library to enable function shape inference. 158 // Without function library, function inference always yields unknown shapes. 159 // With this enabled, shape inference can take more time since it descends 160 // into all function calls. It doesn't do inference once for each function 161 // definition, but once for each function call. 162 // The function library must outlive the shape refiner. set_function_library_for_shape_inference(const tensorflow::FunctionLibraryDefinition * lib)163 void set_function_library_for_shape_inference( 164 const tensorflow::FunctionLibraryDefinition* lib) { 165 function_library_ = lib; 166 } 167 function_shape_inference_supported()168 bool function_shape_inference_supported() const { 169 return function_library_ != nullptr; 170 } 171 172 // Call this to keep nested shapes information for user-defined functions: 173 // nested inferences will be available on the ExtendedInferenceContext for 174 // each function node, forming a tree of shape inferences corresponding to the 175 // tree of nested function calls. By default this setting is disabled, and 176 // only the shapes for the top-level function node will be reported on the 177 // InferenceContext for each function node, to reduce memory usage. 178 // 179 // This flag has no effect when the function inference is not enabled via 180 // set_function_library_for_shape_inference. set_keep_nested_shape_inferences()181 void set_keep_nested_shape_inferences() { 182 keep_nested_shape_inferences_ = true; 183 } 184 185 private: 186 friend class ShapeRefinerTest; 187 friend class ::tensorflow::grappler::GraphProperties; 188 189 // Returns true if the ranks and all dimensions of <s0> and <s1> are either 190 // equal in value or both unknown. 191 static bool SameDefinedShape(shape_inference::InferenceContext* c, 192 shape_inference::ShapeHandle s0, 193 shape_inference::ShapeHandle s1); 194 195 // Returns true if the shapes and types stored in <*existing> are identical in 196 // value to the shapes and types in <*updated>. 197 static bool IsUpdatedShapesOrTypes( 198 shape_inference::InferenceContext* c, 199 const std::vector<shape_inference::ShapeAndType>& existing, 200 const std::vector<shape_inference::ShapeAndType>& updated); 201 202 // Performs shape inference for the given function_def within the 203 // given outer_context. Internally it instantiates the function as a graph 204 // and runs shape inference recursively on it with the input shapes provided 205 // by the outer_context. 206 // 207 // Returns an error if: 208 // - number of inputs/outputs on outer_context doesn't match the function_def 209 // 210 // On success: 211 // - outer_context will contain output shapes inferred from input shapes 212 // - outer_context will contain nested inferences collection, iff 213 // keep_nested_shapes is true 214 Status InferShapesForFunction(const tensorflow::FunctionDef* function_def, 215 bool keep_nested_shapes, 216 ExtendedInferenceContext* outer_context); 217 218 // Tries to infer tensor output based on the input shapes of the node. In some 219 // cases, the shapes of the inputs are sufficient for inferring the contents 220 // of the output tensor. For example, a Shape op with fully defined input 221 // shapes can have its output tensor inferred. 222 Status TryToInferTensorOutputFromInputShapes(const Edge* edge, Tensor* output, 223 bool* success); 224 225 // Extracts the subgraph ending at 'node' that is statically 226 // computable and inserts into 'out_graph'. If statically computable, 227 // 'is_constant_graph' will be true. 228 Status ExtractConstantSubgraph( 229 Node* node, Graph* out_graph, bool* is_constant_graph, 230 std::vector<std::pair<string, Tensor>>* const_inputs) TF_MUST_USE_RESULT; 231 232 Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx, 233 bool* evaluated, Tensor* result); 234 235 // This function tries to materialize as much information about the 'node''s 236 // dst_idx input as a statically computable shape, and the result may be 237 // partially known, depending on what is statically inferable. 238 // 239 // This is called when node.input[dst_idx] is a tensor that is used to define 240 // the shape of some other tensor (e.g., the second argument to Reshape is a 241 // <shape> tensor, where each element of the shape tensor is a dimension of 242 // the target tensor). It returns in <result> a shape for that input. 243 // 244 // Unlike simply resolving node.input[dst_idx] to a constant and then 245 // converting that to a shape, this function can return a partial shape. This 246 // is useful for cases where the shape tensor is only partially defined, such 247 // as with calls for: reshape(x, shape(y)) where shape(y) is partially 248 // defined. 249 // 250 // The implementation has op implementations for ops commonly called on shape 251 // tensors, and the implementations are specialized to shape tensors (namely, 252 // the output is a vector). 253 // 254 // <target_context> is used when creating new DimensionHandle and ShapeHandle 255 // objects. 256 Status ConstantPartialShape(shape_inference::InferenceContext* target_context, 257 const Node* node, int dst_idx, 258 shape_inference::ShapeHandle* result); 259 260 Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data, 261 ExtendedInferenceContext* ec); 262 263 int32 graph_def_version_; 264 const OpRegistryInterface* const ops_registry_; 265 266 // The lifetime of the tensors are bound to the runner, so it should be the 267 // deleted after the tensors. 268 GraphRunner graph_runner_; 269 270 // Stores a map from a node to its ExtendedInferenceContext. 271 std::unordered_map<const Node*, std::unique_ptr<ExtendedInferenceContext>> 272 node_to_context_; 273 274 // Holds a cache from 'tensor name' to the tensor that is 275 // evaluatable as a constant expression. This reduces repeated 276 // execution of the entire constant subgraph as a graph is being 277 // built up. This could be changed to some kind of size-based LRU 278 // cache to avoid consuming too much memory, if that eventually 279 // becomes a concern. 280 // 281 // Only tensors less than 1KiB are currently stored in the cache. 282 static constexpr int64 kMaxTensorSize = 1024; 283 std::unordered_map<string, Tensor> const_tensor_map_; 284 285 bool require_shape_inference_fns_ = true; 286 bool disable_constant_propagation_ = false; 287 288 // Function library is optional, but has to be set to enable function 289 // shape inference. 290 const tensorflow::FunctionLibraryDefinition* function_library_ = nullptr; 291 292 // Determines whether to keep the nested shape inference info for user- 293 // defined functions. By default that info is discarded to save memory. 294 bool keep_nested_shape_inferences_ = false; 295 296 // Cache the graph corresponding to each functin definition for which shapes 297 // are refined. 298 std::unordered_map<const FunctionDef*, std::unique_ptr<const Graph>> 299 functions_; 300 301 TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner); 302 }; 303 304 } // namespace tensorflow 305 306 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ 307