/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {
namespace grappler {
class GraphProperties;
}

// This class stores extra inference information in addition to
// InferenceContext, such as node input and output types.
35 class ExtendedInferenceContext { 36 public: ExtendedInferenceContext(std::unique_ptr<shape_inference::InferenceContext> ic,const Node * node)37 ExtendedInferenceContext( 38 std::unique_ptr<shape_inference::InferenceContext> ic, const Node* node) 39 : inference_context_(std::move(ic)), op_(node->name()) { 40 input_types_.reserve(node->num_inputs()); 41 for (int i = 0; i < node->num_inputs(); i++) { 42 input_types_.push_back(node->input_type(i)); 43 } 44 output_types_.reserve(node->num_outputs()); 45 for (int i = 0; i < node->num_outputs(); i++) { 46 output_types_.push_back(node->output_type(i)); 47 } 48 } 49 input_type(int64_t idx)50 DataType input_type(int64_t idx) const { return input_types_[idx]; } output_type(int64_t idx)51 DataType output_type(int64_t idx) const { return output_types_[idx]; } 52 get_context()53 shape_inference::InferenceContext* get_context() { 54 return inference_context_.get(); 55 } 56 op()57 std::string op() const { return op_; } 58 59 private: 60 std::unique_ptr<shape_inference::InferenceContext> inference_context_; 61 std::string op_; 62 std::vector<DataType> input_types_; 63 std::vector<DataType> output_types_; 64 65 TF_DISALLOW_COPY_AND_ASSIGN(ExtendedInferenceContext); 66 }; 67 68 // ShapeRefiner performs shape inference for TensorFlow Graphs. It is 69 // responsible for instantiating InferenceContext objects for each 70 // Node in the Graph, and providing/storing the 'input_tensor' Tensors 71 // used by Shape Inference functions, when available at graph 72 // construction time. 73 class ShapeRefiner { 74 public: 75 ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops); 76 77 // Same as ShapeRefiner(versions.producer(), ops) 78 ShapeRefiner(const VersionDef& versions, const OpRegistryInterface* ops); 79 80 ~ShapeRefiner(); 81 82 // Performs validation of 'node' and runs 'node's shape function, 83 // storing its shape outputs. 84 // 85 // All inputs of 'node' must be added to ShapeRefiner prior to 86 // adding 'node'. 
87 // 88 // Returns an error if: 89 // - the shape function for 'node' was not registered. 90 // - 'node' was added before its inputs. 91 // - The shape inference function returns an error. 92 Status AddNode(const Node* node); 93 94 // Sets 'node's 'output_port' output to have shape 'shape'. 95 // 96 // Returns an error if 'node' was not previously added to this 97 // object, if 'output_port' is invalid, or if 'shape' is 98 // not compatible with the existing shape of the output. 99 Status SetShape(const Node* node, int output_port, 100 shape_inference::ShapeHandle shape); 101 102 // Update the input shapes of node in case the shapes of the fan-ins of 'node' 103 // have themselves been modified (For example, in case of incremental shape 104 // refinement). If 'relax' is true, a new shape with the broadest set of 105 // information will be set as the new input (see InferenceContext::RelaxInput 106 // for full details and examples). Sets refined to true if any shapes have 107 // changed (in their string representations). Note that shapes may have been 108 // updated to newer versions (but with identical string representations) even 109 // if <*refined> is set to false. 110 Status UpdateNode(const Node* node, bool relax, bool* refined); 111 112 // Returns the InferenceContext for 'node', if present. GetContext(const Node * node)113 shape_inference::InferenceContext* GetContext(const Node* node) const { 114 auto it = node_to_context_.find(node); 115 if (it == node_to_context_.end()) { 116 return nullptr; 117 } 118 return it->second->get_context(); 119 } 120 121 // Returns the ExtendedInferenceContext for 'node', if present. GetExtendedContext(const Node * node)122 ExtendedInferenceContext* GetExtendedContext(const Node* node) const { 123 auto it = node_to_context_.find(node); 124 if (it == node_to_context_.end()) { 125 return nullptr; 126 } 127 return it->second.get(); 128 } 129 130 // Getters and setters for graph_def_version_. 
graph_def_version()131 int32 graph_def_version() const { return graph_def_version_; } set_graph_def_version(int32_t version)132 void set_graph_def_version(int32_t version) { graph_def_version_ = version; } 133 set_require_shape_inference_fns(bool require_shape_inference_fns)134 void set_require_shape_inference_fns(bool require_shape_inference_fns) { 135 require_shape_inference_fns_ = require_shape_inference_fns; 136 } set_disable_constant_propagation(bool disable)137 void set_disable_constant_propagation(bool disable) { 138 disable_constant_propagation_ = disable; 139 } 140 141 // Set function library to enable function shape inference. 142 // Without function library, function inference always yields unknown shapes. 143 // With this enabled, shape inference can take more time since it descends 144 // into all function calls. It doesn't do inference once for each function 145 // definition, but once for each function call. 146 // The function library must outlive the shape refiner. set_function_library_for_shape_inference(const tensorflow::FunctionLibraryDefinition * lib)147 void set_function_library_for_shape_inference( 148 const tensorflow::FunctionLibraryDefinition* lib) { 149 function_library_ = lib; 150 } 151 function_shape_inference_supported()152 bool function_shape_inference_supported() const { 153 return function_library_ != nullptr; 154 } 155 156 private: 157 friend class ShapeRefinerTest; 158 friend class ::tensorflow::grappler::GraphProperties; 159 160 // Returns true if the ranks and all dimensions of <s0> and <s1> are either 161 // equal in value or both unknown. 162 static bool SameDefinedShape(shape_inference::InferenceContext* c, 163 shape_inference::ShapeHandle s0, 164 shape_inference::ShapeHandle s1); 165 166 // Returns true if the shapes and types stored in <*existing> are identical in 167 // value to the shapes and types in <*updated>. 
168 static bool IsUpdatedShapesOrTypes( 169 shape_inference::InferenceContext* c, 170 const std::vector<shape_inference::ShapeAndType>& existing, 171 const std::vector<shape_inference::ShapeAndType>& updated); 172 173 // Performs shape inference for the given function_def within the 174 // given outer_context. Internally it instantiates the function as a graph 175 // and runs shape inference recursively on it with the input shapes provided 176 // by the outer_context. 177 // 178 // Returns an error if: 179 // - number of inputs/outputs on outer_context doesn't match the function_def 180 // 181 // On success: 182 // - outer_context will contain output shapes inferred from input shapes 183 Status InferShapesForFunction(const FunctionDef* function_def, 184 AttrSlice attributes, 185 ExtendedInferenceContext* outer_context); 186 187 // Performs shape inference for a node inside a function. 188 // 189 // 'outer_context' is the 'InferenceContext' for the function's call op. 190 Status InferShapesForFunctionSubNode( 191 const Node* node, shape_inference::InferenceContext* outer_context); 192 193 // Performs validation of 'node' and runs 'node's shape function, 194 // storing its shape outputs. 195 // 196 // All inputs of 'node' must be added to ShapeRefiner prior to 197 // adding 'node'. 198 // 199 // Optionally, if 'node' is in a nested function, the 'InferenceContext' for 200 // the call op of the function can be passed as 'outer_context' (pass nullptr 201 // otherwise). This gets used to perform constant propagation across Arg nodes 202 // by requesting the constant of value of the incoming tensor from the 203 // 'outer_context'. 204 // 205 // Returns an error if: 206 // - the shape function for 'node' was not registered. 207 // - 'node' was added before its inputs. 208 // - The shape inference function returns an error. 
209 Status AddNodeInternal(const Node* node, 210 shape_inference::InferenceContext* outer_context); 211 212 // Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge 213 // value can be evaluated, 'evaluated' is set to true and the value returned 214 // in 'result'. Otherwise 'evaluated' is set to false. 215 // 216 // Optionally, if 'node' is in a nested function, the 'InferenceContext' for 217 // the call op of the function can be passed as 'outer_context' (pass nullptr 218 // otherwise). This gets used to perform constant propagation across Arg nodes 219 // by requesting the constant of value of the incoming tensor from the 220 // 'outer_context'. 221 Status EvaluateConstantTensorForEdge( 222 const Node* node, int dst_idx, bool* evaluated, Tensor* result, 223 shape_inference::InferenceContext* outer_context); 224 225 // Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input 226 // tensors. The caller is responsible for checking that the specified edge is 227 // scalar and int32 or int64. 228 // 229 // Optionally, if 'node' is in a nested function, the 'InferenceContext' for 230 // the call op of the function can be passed as 'outer_context' (pass nullptr 231 // otherwise). This gets used to perform constant propagation across Arg nodes 232 // by requesting the constant of value of the incoming tensor from the 233 // 'outer_context'. 234 Status EvaluateConstantIntScalarEdge( 235 const Node* node, int dst_idx, bool* evaluated, int64* result, 236 shape_inference::InferenceContext* outer_context); 237 238 // This function tries to materialize as much information about the 'node''s 239 // dst_idx input as a statically computable shape, and the result may be 240 // partially known, depending on what is statically inferable. 
241 // 242 // This is called when node.input[dst_idx] is a tensor that is used to define 243 // the shape of some other tensor (e.g., the second argument to Reshape is a 244 // <shape> tensor, where each element of the shape tensor is a dimension of 245 // the target tensor). It returns in <result> a shape for that input. 246 // 247 // Unlike simply resolving node.input[dst_idx] to a constant and then 248 // converting that to a shape, this function can return a partial shape. This 249 // is useful for cases where the shape tensor is only partially defined, such 250 // as with calls for: reshape(x, shape(y)) where shape(y) is partially 251 // defined. 252 // 253 // The implementation has op implementations for ops commonly called on shape 254 // tensors, and the implementations are specialized to shape tensors (namely, 255 // the output is a vector). 256 // 257 // <target_context> is used when creating new DimensionHandle and ShapeHandle 258 // objects. 259 // 260 // Optionally, if 'node' is in a nested function, the 'InferenceContext' for 261 // the call op of the function can be passed as 'outer_context' (pass nullptr 262 // otherwise). This gets used to perform constant propagation across Arg nodes 263 // by requesting the constant of value of the incoming tensor from the 264 // 'outer_context'. 265 Status ConstantPartialShape(shape_inference::InferenceContext* target_context, 266 const Node* node, int dst_idx, 267 shape_inference::ShapeHandle* result, 268 shape_inference::InferenceContext* outer_context); 269 270 // Implementation of ConstantPartialShape for StridedSlice nodes. 271 // 272 // Optionally, if 'node' is in a nested function, the 'InferenceContext' for 273 // the call op of the function can be passed as 'outer_context' (pass nullptr 274 // otherwise). This gets used to perform constant propagation across Arg nodes 275 // by requesting the constant of value of the incoming tensor from the 276 // 'outer_context'. 
277 Status PartialStridedSliceShape( 278 Node* slice_node, shape_inference::InferenceContext* ctx, 279 shape_inference::ShapeHandle* result, 280 shape_inference::InferenceContext* outer_context); 281 282 // Runs the shape function registered for the node's op type. 283 // 284 // Optionally, if 'node' is in a nested function, the 'InferenceContext' for 285 // the call op of the function can be passed as 'outer_context' (pass nullptr 286 // otherwise). This gets used to perform constant propagation across Arg nodes 287 // by requesting the constant of value of the incoming tensor from the 288 // 'outer_context'. 289 Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data, 290 ExtendedInferenceContext* ec, 291 shape_inference::InferenceContext* outer_context = nullptr); 292 293 int32 graph_def_version_; 294 const OpRegistryInterface* const ops_registry_; 295 296 // The lifetime of the tensors are bound to the runner, so it should be the 297 // deleted after the tensors. 298 GraphRunner graph_runner_; 299 300 // Stores a map from a node to its ExtendedInferenceContext. 301 absl::flat_hash_map<const Node*, std::unique_ptr<ExtendedInferenceContext>, 302 hash<const Node*>> 303 node_to_context_; 304 305 // Holds a cache from 'tensor name' to the tensor that is 306 // evaluatable as a constant expression. This reduces repeated 307 // execution of the entire constant subgraph as a graph is being 308 // built up. This could be changed to some kind of size-based LRU 309 // cache to avoid consuming too much memory, if that eventually 310 // becomes a concern. 311 // 312 // Only tensors less than 1KiB are currently stored in the cache. 313 static constexpr int64_t kMaxTensorSize = 1024; 314 std::unordered_map<string, Tensor> const_tensor_map_; 315 316 bool require_shape_inference_fns_ = true; 317 bool disable_constant_propagation_ = false; 318 319 // Function library is optional, but has to be set to enable function 320 // shape inference. 
321 const tensorflow::FunctionLibraryDefinition* function_library_ = nullptr; 322 323 // Cache the graph corresponding to each function definition for which shapes 324 // are refined. 325 absl::flat_hash_map<const FunctionDef*, std::unique_ptr<const Graph>, 326 hash<const FunctionDef*>> 327 functions_; 328 329 TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner); 330 }; 331 332 } // namespace tensorflow 333 334 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ 335