• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
17 
18 #include <vector>
19 
20 #include "tensorflow/core/common_runtime/graph_runner.h"
21 #include "tensorflow/core/framework/function.pb.h"
22 #include "tensorflow/core/framework/shape_inference.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/platform/macros.h"
26 
namespace tensorflow {
// Forward declaration so ShapeRefiner can befriend
// ::tensorflow::grappler::GraphProperties (see the friend declaration in the
// class below) without pulling grappler headers into this header.
namespace grappler {
class GraphProperties;
}
32 // This class stores extra inference information in addition to
33 // InferenceContext, such as inference tree for user-defined functions and node
34 // input and output types.
35 class ExtendedInferenceContext {
36  public:
ExtendedInferenceContext(std::unique_ptr<shape_inference::InferenceContext> ic,const Node * node)37   ExtendedInferenceContext(
38       std::unique_ptr<shape_inference::InferenceContext> ic, const Node* node)
39       : inference_context_(std::move(ic)) {
40     input_types_.reserve(node->num_inputs());
41     for (int i = 0; i < node->num_inputs(); i++) {
42       input_types_.push_back(node->input_type(i));
43     }
44     output_types_.reserve(node->num_outputs());
45     for (int i = 0; i < node->num_outputs(); i++) {
46       output_types_.push_back(node->output_type(i));
47     }
48   }
49 
50   const std::unordered_map<string, std::unique_ptr<ExtendedInferenceContext>>&
nested_inferences()51   nested_inferences() const {
52     return nested_inferences_;
53   }
input_type(int64 idx)54   DataType input_type(int64 idx) const { return input_types_[idx]; }
output_type(int64 idx)55   DataType output_type(int64 idx) const { return output_types_[idx]; }
56 
get_context()57   shape_inference::InferenceContext* get_context() {
58     return inference_context_.get();
59   }
60 
61   // Sets nested inference info.
62   // For composite ops (user-defined functions) only.
63   // Inference for trivial ops must not call this setter.
set_nested_inferences(std::unordered_map<string,std::unique_ptr<ExtendedInferenceContext>> inferences)64   void set_nested_inferences(
65       std::unordered_map<string, std::unique_ptr<ExtendedInferenceContext>>
66           inferences) {
67     nested_inferences_ = std::move(inferences);
68   }
69 
70  private:
71   std::unique_ptr<shape_inference::InferenceContext> inference_context_;
72   std::vector<DataType> input_types_;
73   std::vector<DataType> output_types_;
74 
75   // Nested inferences for composite ops (user-defined functions).
76   // Mapping key is nested node name.
77   // For trivial ops this map must be empty.
78   std::unordered_map<string, std::unique_ptr<ExtendedInferenceContext>>
79       nested_inferences_;
80 
81   TF_DISALLOW_COPY_AND_ASSIGN(ExtendedInferenceContext);
82 };
83 
84 // ShapeRefiner performs shape inference for TensorFlow Graphs.  It is
85 // responsible for instantiating InferenceContext objects for each
86 // Node in the Graph, and providing/storing the 'input_tensor' Tensors
87 // used by Shape Inference functions, when available at graph
88 // construction time.
89 class ShapeRefiner {
90  public:
91   ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops);
92 
93   // Same as ShapeRefiner(versions.producer(), ops)
94   ShapeRefiner(const VersionDef& versions, const OpRegistryInterface* ops);
95 
96   ~ShapeRefiner();
97 
98   // Performs validation of 'node' and runs 'node's shape function,
99   // storing its shape outputs.
100   //
101   // All inputs of 'node' must be added to ShapeRefiner prior to
102   // adding 'node'.
103   //
104   // Returns an error if:
105   //  - the shape function for 'node' was not registered.
106   //  - 'node' was added before its inputs.
107   //  - The shape inference function returns an error.
108   Status AddNode(const Node* node);
109 
110   // Sets 'node's 'output_port' output to have shape 'shape'.
111   //
112   // Returns an error if 'node' was not previously added to this
113   // object, if 'output_port' is invalid, or if 'shape' is
114   // not compatible with the existing shape of the output.
115   Status SetShape(const Node* node, int output_port,
116                   shape_inference::ShapeHandle shape);
117 
118   // Update the input shapes of node in case the shapes of the fan-ins of 'node'
119   // have themselves been modified (For example, in case of incremental shape
120   // refinement). If 'relax' is true, a new shape with the broadest set of
121   // information will be set as the new input (see InferenceContext::RelaxInput
122   // for full details and examples). Sets refined to true if any shapes have
123   // changed (in their string representations). Note that shapes may have been
124   // updated to newer versions (but with identical string representations) even
125   // if <*refined> is set to false.
126   Status UpdateNode(const Node* node, bool relax, bool* refined);
127 
128   // Returns the InferenceContext for 'node', if present.
GetContext(const Node * node)129   shape_inference::InferenceContext* GetContext(const Node* node) const {
130     auto it = node_to_context_.find(node);
131     if (it == node_to_context_.end()) {
132       return nullptr;
133     }
134     return it->second->get_context();
135   }
136 
137   // Returns the ExtendedInferenceContext for 'node', if present.
GetExtendedContext(const Node * node)138   ExtendedInferenceContext* GetExtendedContext(const Node* node) const {
139     auto it = node_to_context_.find(node);
140     if (it == node_to_context_.end()) {
141       return nullptr;
142     }
143     return it->second.get();
144   }
145 
146   // Getters and setters for graph_def_version_.
graph_def_version()147   int32 graph_def_version() const { return graph_def_version_; }
set_graph_def_version(int32 version)148   void set_graph_def_version(int32 version) { graph_def_version_ = version; }
149 
set_require_shape_inference_fns(bool require_shape_inference_fns)150   void set_require_shape_inference_fns(bool require_shape_inference_fns) {
151     require_shape_inference_fns_ = require_shape_inference_fns;
152   }
set_disable_constant_propagation(bool disable)153   void set_disable_constant_propagation(bool disable) {
154     disable_constant_propagation_ = disable;
155   }
156 
157   // Set function library to enable function shape inference.
158   // Without function library, function inference always yields unknown shapes.
159   // With this enabled, shape inference can take more time since it descends
160   // into all function calls. It doesn't do inference once for each function
161   // definition, but once for each function call.
162   // The function library must outlive the shape refiner.
set_function_library_for_shape_inference(const tensorflow::FunctionLibraryDefinition * lib)163   void set_function_library_for_shape_inference(
164       const tensorflow::FunctionLibraryDefinition* lib) {
165     function_library_ = lib;
166   }
167 
function_shape_inference_supported()168   bool function_shape_inference_supported() const {
169     return function_library_ != nullptr;
170   }
171 
172   // Call this to keep nested shapes information for user-defined functions:
173   // nested inferences will be available on the ExtendedInferenceContext for
174   // each function node, forming a tree of shape inferences corresponding to the
175   // tree of nested function calls. By default this setting is disabled, and
176   // only the shapes for the top-level function node will be reported on the
177   // InferenceContext for each function node, to reduce memory usage.
178   //
179   // This flag has no effect when the function inference is not enabled via
180   // set_function_library_for_shape_inference.
set_keep_nested_shape_inferences()181   void set_keep_nested_shape_inferences() {
182     keep_nested_shape_inferences_ = true;
183   }
184 
185  private:
186   friend class ShapeRefinerTest;
187   friend class ::tensorflow::grappler::GraphProperties;
188 
189   // Returns true if the ranks and all dimensions of <s0> and <s1> are either
190   // equal in value or both unknown.
191   static bool SameDefinedShape(shape_inference::InferenceContext* c,
192                                shape_inference::ShapeHandle s0,
193                                shape_inference::ShapeHandle s1);
194 
195   // Returns true if the shapes and types stored in <*existing> are identical in
196   // value to the shapes and types in <*updated>.
197   static bool IsUpdatedShapesOrTypes(
198       shape_inference::InferenceContext* c,
199       const std::vector<shape_inference::ShapeAndType>& existing,
200       const std::vector<shape_inference::ShapeAndType>& updated);
201 
202   // Performs shape inference for the given function_def within the
203   // given outer_context. Internally it instantiates the function as a graph
204   // and runs shape inference recursively on it with the input shapes provided
205   // by the outer_context.
206   //
207   // Returns an error if:
208   // - number of inputs/outputs on outer_context doesn't match the function_def
209   //
210   // On success:
211   // - outer_context will contain output shapes inferred from input shapes
212   // - outer_context will contain nested inferences collection, iff
213   //   keep_nested_shapes is true
214   Status InferShapesForFunction(const tensorflow::FunctionDef* function_def,
215                                 bool keep_nested_shapes,
216                                 ExtendedInferenceContext* outer_context);
217 
218   // Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge
219   // value can be evaluated, 'evaluated' is set to true and the value returned
220   // in 'result'. Otherwise 'evaluated' is set to false.
221   Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx,
222                                        bool* evaluated, Tensor* result);
223 
224   // Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input
225   // tensors. The caller is responsible for checking that the specified edge is
226   // scalar and int32 or int64.
227   Status EvaluateConstantIntScalarEdge(const Node* node, int dst_idx,
228                                        bool* evaluated, int64* result);
229 
230   // This function tries to materialize as much information about the 'node''s
231   // dst_idx input as a statically computable shape, and the result may be
232   // partially known, depending on what is statically inferable.
233   //
234   // This is called when node.input[dst_idx] is a tensor that is used to define
235   // the shape of some other tensor (e.g., the second argument to Reshape is a
236   // <shape> tensor, where each element of the shape tensor is a dimension of
237   // the target tensor).  It returns in <result> a shape for that input.
238   //
239   // Unlike simply resolving node.input[dst_idx] to a constant and then
240   // converting that to a shape, this function can return a partial shape. This
241   // is useful for cases where the shape tensor is only partially defined, such
242   // as with calls for: reshape(x, shape(y)) where shape(y) is partially
243   // defined.
244   //
245   // The implementation has op implementations for ops commonly called on shape
246   // tensors, and the implementations are specialized to shape tensors (namely,
247   // the output is a vector).
248   //
249   // <target_context> is used when creating new DimensionHandle and ShapeHandle
250   // objects.
251   Status ConstantPartialShape(shape_inference::InferenceContext* target_context,
252                               const Node* node, int dst_idx,
253                               shape_inference::ShapeHandle* result);
254 
255   // Implementation of ConstantPartialShape for StridedSlice nodes.
256   Status PartialStridedSliceShape(Node* slice_node,
257                                   shape_inference::InferenceContext* ctx,
258                                   shape_inference::ShapeHandle* result);
259 
260   Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
261                     ExtendedInferenceContext* ec);
262 
263   int32 graph_def_version_;
264   const OpRegistryInterface* const ops_registry_;
265 
266   // The lifetime of the tensors are bound to the runner, so it should be the
267   // deleted after the tensors.
268   GraphRunner graph_runner_;
269 
270   // Stores a map from a node to its ExtendedInferenceContext.
271   std::unordered_map<const Node*, std::unique_ptr<ExtendedInferenceContext>>
272       node_to_context_;
273 
274   // Holds a cache from 'tensor name' to the tensor that is
275   // evaluatable as a constant expression.  This reduces repeated
276   // execution of the entire constant subgraph as a graph is being
277   // built up.  This could be changed to some kind of size-based LRU
278   // cache to avoid consuming too much memory, if that eventually
279   // becomes a concern.
280   //
281   // Only tensors less than 1KiB are currently stored in the cache.
282   static constexpr int64 kMaxTensorSize = 1024;
283   std::unordered_map<string, Tensor> const_tensor_map_;
284 
285   bool require_shape_inference_fns_ = true;
286   bool disable_constant_propagation_ = false;
287 
288   // Function library is optional, but has to be set to enable function
289   // shape inference.
290   const tensorflow::FunctionLibraryDefinition* function_library_ = nullptr;
291 
292   // Determines whether to keep the nested shape inference info for user-
293   // defined functions. By default that info is discarded to save memory.
294   bool keep_nested_shape_inferences_ = false;
295 
296   // Cache the graph corresponding to each functin definition for which shapes
297   // are refined.
298   std::unordered_map<const FunctionDef*, std::unique_ptr<const Graph>>
299       functions_;
300 
301   TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner);
302 };
303 
304 }  // namespace tensorflow
305 
306 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
307