/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {
namespace grappler {
class GraphProperties;
}

// This class stores extra inference information in addition to
// InferenceContext, such as node input and output types.
34 class ExtendedInferenceContext { 35 public: ExtendedInferenceContext(std::unique_ptr<shape_inference::InferenceContext> ic,const Node * node)36 ExtendedInferenceContext( 37 std::unique_ptr<shape_inference::InferenceContext> ic, const Node* node) 38 : inference_context_(std::move(ic)), op_(node->name()) { 39 input_types_.reserve(node->num_inputs()); 40 for (int i = 0; i < node->num_inputs(); i++) { 41 input_types_.push_back(node->input_type(i)); 42 } 43 output_types_.reserve(node->num_outputs()); 44 for (int i = 0; i < node->num_outputs(); i++) { 45 output_types_.push_back(node->output_type(i)); 46 } 47 } 48 input_type(int64 idx)49 DataType input_type(int64 idx) const { return input_types_[idx]; } output_type(int64 idx)50 DataType output_type(int64 idx) const { return output_types_[idx]; } 51 get_context()52 shape_inference::InferenceContext* get_context() { 53 return inference_context_.get(); 54 } 55 op()56 std::string op() const { return op_; } 57 58 private: 59 std::unique_ptr<shape_inference::InferenceContext> inference_context_; 60 std::string op_; 61 std::vector<DataType> input_types_; 62 std::vector<DataType> output_types_; 63 64 TF_DISALLOW_COPY_AND_ASSIGN(ExtendedInferenceContext); 65 }; 66 67 // ShapeRefiner performs shape inference for TensorFlow Graphs. It is 68 // responsible for instantiating InferenceContext objects for each 69 // Node in the Graph, and providing/storing the 'input_tensor' Tensors 70 // used by Shape Inference functions, when available at graph 71 // construction time. 72 class ShapeRefiner { 73 public: 74 ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops); 75 76 // Same as ShapeRefiner(versions.producer(), ops) 77 ShapeRefiner(const VersionDef& versions, const OpRegistryInterface* ops); 78 79 ~ShapeRefiner(); 80 81 // Performs validation of 'node' and runs 'node's shape function, 82 // storing its shape outputs. 83 // 84 // All inputs of 'node' must be added to ShapeRefiner prior to 85 // adding 'node'. 
86 // 87 // Returns an error if: 88 // - the shape function for 'node' was not registered. 89 // - 'node' was added before its inputs. 90 // - The shape inference function returns an error. 91 Status AddNode(const Node* node); 92 93 // Sets 'node's 'output_port' output to have shape 'shape'. 94 // 95 // Returns an error if 'node' was not previously added to this 96 // object, if 'output_port' is invalid, or if 'shape' is 97 // not compatible with the existing shape of the output. 98 Status SetShape(const Node* node, int output_port, 99 shape_inference::ShapeHandle shape); 100 101 // Update the input shapes of node in case the shapes of the fan-ins of 'node' 102 // have themselves been modified (For example, in case of incremental shape 103 // refinement). If 'relax' is true, a new shape with the broadest set of 104 // information will be set as the new input (see InferenceContext::RelaxInput 105 // for full details and examples). Sets refined to true if any shapes have 106 // changed (in their string representations). Note that shapes may have been 107 // updated to newer versions (but with identical string representations) even 108 // if <*refined> is set to false. 109 Status UpdateNode(const Node* node, bool relax, bool* refined); 110 111 // Returns the InferenceContext for 'node', if present. GetContext(const Node * node)112 shape_inference::InferenceContext* GetContext(const Node* node) const { 113 auto it = node_to_context_.find(node); 114 if (it == node_to_context_.end()) { 115 return nullptr; 116 } 117 return it->second->get_context(); 118 } 119 120 // Returns the ExtendedInferenceContext for 'node', if present. GetExtendedContext(const Node * node)121 ExtendedInferenceContext* GetExtendedContext(const Node* node) const { 122 auto it = node_to_context_.find(node); 123 if (it == node_to_context_.end()) { 124 return nullptr; 125 } 126 return it->second.get(); 127 } 128 129 // Getters and setters for graph_def_version_. 
graph_def_version()130 int32 graph_def_version() const { return graph_def_version_; } set_graph_def_version(int32 version)131 void set_graph_def_version(int32 version) { graph_def_version_ = version; } 132 set_require_shape_inference_fns(bool require_shape_inference_fns)133 void set_require_shape_inference_fns(bool require_shape_inference_fns) { 134 require_shape_inference_fns_ = require_shape_inference_fns; 135 } set_disable_constant_propagation(bool disable)136 void set_disable_constant_propagation(bool disable) { 137 disable_constant_propagation_ = disable; 138 } 139 140 // Set function library to enable function shape inference. 141 // Without function library, function inference always yields unknown shapes. 142 // With this enabled, shape inference can take more time since it descends 143 // into all function calls. It doesn't do inference once for each function 144 // definition, but once for each function call. 145 // The function library must outlive the shape refiner. set_function_library_for_shape_inference(const tensorflow::FunctionLibraryDefinition * lib)146 void set_function_library_for_shape_inference( 147 const tensorflow::FunctionLibraryDefinition* lib) { 148 function_library_ = lib; 149 } 150 function_shape_inference_supported()151 bool function_shape_inference_supported() const { 152 return function_library_ != nullptr; 153 } 154 155 private: 156 friend class ShapeRefinerTest; 157 friend class ::tensorflow::grappler::GraphProperties; 158 159 // Returns true if the ranks and all dimensions of <s0> and <s1> are either 160 // equal in value or both unknown. 161 static bool SameDefinedShape(shape_inference::InferenceContext* c, 162 shape_inference::ShapeHandle s0, 163 shape_inference::ShapeHandle s1); 164 165 // Returns true if the shapes and types stored in <*existing> are identical in 166 // value to the shapes and types in <*updated>. 
167 static bool IsUpdatedShapesOrTypes( 168 shape_inference::InferenceContext* c, 169 const std::vector<shape_inference::ShapeAndType>& existing, 170 const std::vector<shape_inference::ShapeAndType>& updated); 171 172 // Performs shape inference for the given function_def within the 173 // given outer_context. Internally it instantiates the function as a graph 174 // and runs shape inference recursively on it with the input shapes provided 175 // by the outer_context. 176 // 177 // Returns an error if: 178 // - number of inputs/outputs on outer_context doesn't match the function_def 179 // 180 // On success: 181 // - outer_context will contain output shapes inferred from input shapes 182 Status InferShapesForFunction(const FunctionDef* function_def, 183 AttrSlice attributes, 184 ExtendedInferenceContext* outer_context); 185 186 // Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge 187 // value can be evaluated, 'evaluated' is set to true and the value returned 188 // in 'result'. Otherwise 'evaluated' is set to false. 189 Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx, 190 bool* evaluated, Tensor* result); 191 192 // Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input 193 // tensors. The caller is responsible for checking that the specified edge is 194 // scalar and int32 or int64. 195 Status EvaluateConstantIntScalarEdge(const Node* node, int dst_idx, 196 bool* evaluated, int64* result); 197 198 // This function tries to materialize as much information about the 'node''s 199 // dst_idx input as a statically computable shape, and the result may be 200 // partially known, depending on what is statically inferable. 201 // 202 // This is called when node.input[dst_idx] is a tensor that is used to define 203 // the shape of some other tensor (e.g., the second argument to Reshape is a 204 // <shape> tensor, where each element of the shape tensor is a dimension of 205 // the target tensor). 
It returns in <result> a shape for that input. 206 // 207 // Unlike simply resolving node.input[dst_idx] to a constant and then 208 // converting that to a shape, this function can return a partial shape. This 209 // is useful for cases where the shape tensor is only partially defined, such 210 // as with calls for: reshape(x, shape(y)) where shape(y) is partially 211 // defined. 212 // 213 // The implementation has op implementations for ops commonly called on shape 214 // tensors, and the implementations are specialized to shape tensors (namely, 215 // the output is a vector). 216 // 217 // <target_context> is used when creating new DimensionHandle and ShapeHandle 218 // objects. 219 Status ConstantPartialShape(shape_inference::InferenceContext* target_context, 220 const Node* node, int dst_idx, 221 shape_inference::ShapeHandle* result); 222 223 // Implementation of ConstantPartialShape for StridedSlice nodes. 224 Status PartialStridedSliceShape(Node* slice_node, 225 shape_inference::InferenceContext* ctx, 226 shape_inference::ShapeHandle* result); 227 228 Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data, 229 ExtendedInferenceContext* ec); 230 231 int32 graph_def_version_; 232 const OpRegistryInterface* const ops_registry_; 233 234 // The lifetime of the tensors are bound to the runner, so it should be the 235 // deleted after the tensors. 236 GraphRunner graph_runner_; 237 238 // Stores a map from a node to its ExtendedInferenceContext. 239 std::unordered_map<const Node*, std::unique_ptr<ExtendedInferenceContext>> 240 node_to_context_; 241 242 // Holds a cache from 'tensor name' to the tensor that is 243 // evaluatable as a constant expression. This reduces repeated 244 // execution of the entire constant subgraph as a graph is being 245 // built up. This could be changed to some kind of size-based LRU 246 // cache to avoid consuming too much memory, if that eventually 247 // becomes a concern. 
248 // 249 // Only tensors less than 1KiB are currently stored in the cache. 250 static constexpr int64 kMaxTensorSize = 1024; 251 std::unordered_map<string, Tensor> const_tensor_map_; 252 253 bool require_shape_inference_fns_ = true; 254 bool disable_constant_propagation_ = false; 255 256 // Function library is optional, but has to be set to enable function 257 // shape inference. 258 const tensorflow::FunctionLibraryDefinition* function_library_ = nullptr; 259 260 // Cache the graph corresponding to each functin definition for which shapes 261 // are refined. 262 std::unordered_map<const FunctionDef*, std::unique_ptr<const Graph>> 263 functions_; 264 265 TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner); 266 }; 267 268 } // namespace tensorflow 269 270 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ 271