/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_

#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {
namespace grappler {
class GraphProperties;
}

// This class stores extra inference information in addition to
// InferenceContext, such as node input and output types.
class ExtendedInferenceContext {
 public:
  ExtendedInferenceContext(
      std::unique_ptr<shape_inference::InferenceContext> ic, const Node* node)
      : inference_context_(std::move(ic)), op_(node->name()) {
    input_types_.reserve(node->num_inputs());
    for (int i = 0; i < node->num_inputs(); i++) {
      input_types_.push_back(node->input_type(i));
    }
    output_types_.reserve(node->num_outputs());
    for (int i = 0; i < node->num_outputs(); i++) {
      output_types_.push_back(node->output_type(i));
    }
  }

  DataType input_type(int64_t idx) const { return input_types_[idx]; }
  DataType output_type(int64_t idx) const { return output_types_[idx]; }

  shape_inference::InferenceContext* get_context() {
    return inference_context_.get();
  }

  std::string op() const { return op_; }

 private:
  std::unique_ptr<shape_inference::InferenceContext> inference_context_;
  std::string op_;
  std::vector<DataType> input_types_;
  std::vector<DataType> output_types_;

  TF_DISALLOW_COPY_AND_ASSIGN(ExtendedInferenceContext);
};

// ShapeRefiner performs shape inference for TensorFlow Graphs.  It is
// responsible for instantiating InferenceContext objects for each
// Node in the Graph, and providing/storing the 'input_tensor' Tensors
// used by Shape Inference functions, when available at graph
// construction time.
class ShapeRefiner {
 public:
  ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops);

  // Same as ShapeRefiner(versions.producer(), ops)
  ShapeRefiner(const VersionDef& versions, const OpRegistryInterface* ops);

  ~ShapeRefiner();

  // Performs validation of 'node' and runs 'node's shape function,
  // storing its shape outputs.
  //
  // All inputs of 'node' must be added to ShapeRefiner prior to
  // adding 'node'.
  //
  // Returns an error if:
  //  - the shape function for 'node' was not registered.
  //  - 'node' was added before its inputs.
  //  - The shape inference function returns an error.
  Status AddNode(const Node* node);
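
  // Example (illustrative sketch, not part of the original header): a typical
  // caller builds the refiner from the graph's versions and op registry, then
  // adds nodes so that every input precedes its consumers. 'graph' is assumed
  // to be a tensorflow::Graph owned by the caller; GetReversePostOrder (from
  // tensorflow/core/graph/algorithm.h) yields such an order for a DAG.
  //
  //   ShapeRefiner refiner(graph.versions(), graph.op_registry());
  //   std::vector<Node*> order;
  //   GetReversePostOrder(graph, &order);  // inputs before consumers
  //   for (const Node* node : order) {
  //     TF_RETURN_IF_ERROR(refiner.AddNode(node));
  //   }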
93 
94   // Sets 'node's 'output_port' output to have shape 'shape'.
95   //
96   // Returns an error if 'node' was not previously added to this
97   // object, if 'output_port' is invalid, or if 'shape' is
98   // not compatible with the existing shape of the output.
99   Status SetShape(const Node* node, int output_port,
100                   shape_inference::ShapeHandle shape);
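
  // Example (illustrative sketch): overriding the inferred shape of output 0
  // of a previously added 'node' with a shape built from a caller-supplied
  // TensorShapeProto 'proto'. The ShapeHandle must be created through the
  // node's own InferenceContext.
  //
  //   shape_inference::InferenceContext* c = refiner.GetContext(node);
  //   shape_inference::ShapeHandle handle;
  //   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(proto, &handle));
  //   TF_RETURN_IF_ERROR(refiner.SetShape(node, /*output_port=*/0, handle));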

  // Updates the input shapes of 'node' in case the shapes of its fan-ins have
  // themselves been modified (for example, during incremental shape
  // refinement). If 'relax' is true, a new shape with the broadest set of
  // information will be set as the new input (see InferenceContext::RelaxInput
  // for full details and examples). Sets <*refined> to true if any shapes have
  // changed (in their string representations). Note that shapes may have been
  // updated to newer versions (but with identical string representations) even
  // if <*refined> is set to false.
  Status UpdateNode(const Node* node, bool relax, bool* refined);
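
  // Example (illustrative sketch): re-running inference for 'node' after the
  // shapes of its fan-ins changed, then checking whether anything was refined.
  //
  //   bool refined = false;
  //   TF_RETURN_IF_ERROR(refiner.UpdateNode(node, /*relax=*/false, &refined));
  //   if (refined) {
  //     // Propagate the change, e.g. by calling UpdateNode on 'node's
  //     // consumers as well.
  //   }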

  // Returns the InferenceContext for 'node', if present.
  shape_inference::InferenceContext* GetContext(const Node* node) const {
    auto it = node_to_context_.find(node);
    if (it == node_to_context_.end()) {
      return nullptr;
    }
    return it->second->get_context();
  }
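
  // Example (illustrative sketch): reading back the inferred shape of a
  // node's first output after AddNode has succeeded.
  //
  //   if (shape_inference::InferenceContext* c = refiner.GetContext(node)) {
  //     shape_inference::ShapeHandle out = c->output(0);
  //     LOG(INFO) << "inferred shape: " << c->DebugString(out);
  //   }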

  // Returns the ExtendedInferenceContext for 'node', if present.
  ExtendedInferenceContext* GetExtendedContext(const Node* node) const {
    auto it = node_to_context_.find(node);
    if (it == node_to_context_.end()) {
      return nullptr;
    }
    return it->second.get();
  }

  // Getters and setters for graph_def_version_.
  int32 graph_def_version() const { return graph_def_version_; }
  void set_graph_def_version(int32_t version) { graph_def_version_ = version; }

  void set_require_shape_inference_fns(bool require_shape_inference_fns) {
    require_shape_inference_fns_ = require_shape_inference_fns;
  }
  void set_disable_constant_propagation(bool disable) {
    disable_constant_propagation_ = disable;
  }

  // Sets the function library to enable function shape inference.
  // Without a function library, function inference always yields unknown
  // shapes. With it enabled, shape inference can take more time since it
  // descends into all function calls; inference runs once for each function
  // call, not once per function definition.
  // The function library must outlive the shape refiner.
  void set_function_library_for_shape_inference(
      const tensorflow::FunctionLibraryDefinition* lib) {
    function_library_ = lib;
  }

  bool function_shape_inference_supported() const {
    return function_library_ != nullptr;
  }
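
  // Example (illustrative sketch): enabling function shape inference with a
  // FunctionLibraryDefinition built from a GraphDef's library. 'graph_def' is
  // assumed to come from the caller, and 'flib' must stay alive for as long
  // as the refiner is used.
  //
  //   FunctionLibraryDefinition flib(OpRegistry::Global(),
  //                                  graph_def.library());
  //   refiner.set_function_library_for_shape_inference(&flib);
  //   DCHECK(refiner.function_shape_inference_supported());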

 private:
  friend class ShapeRefinerTest;
  friend class ::tensorflow::grappler::GraphProperties;

  // Returns true if the ranks and all dimensions of <s0> and <s1> are either
  // equal in value or both unknown.
  static bool SameDefinedShape(shape_inference::InferenceContext* c,
                               shape_inference::ShapeHandle s0,
                               shape_inference::ShapeHandle s1);

  // Returns true if the shapes and types stored in <*existing> are identical in
  // value to the shapes and types in <*updated>.
  static bool IsUpdatedShapesOrTypes(
      shape_inference::InferenceContext* c,
      const std::vector<shape_inference::ShapeAndType>& existing,
      const std::vector<shape_inference::ShapeAndType>& updated);

  // Performs shape inference for the given function_def within the
  // given outer_context. Internally it instantiates the function as a graph
  // and runs shape inference recursively on it with the input shapes provided
  // by the outer_context.
  //
  // Returns an error if:
  // - the number of inputs/outputs of 'outer_context' doesn't match
  //   'function_def'.
  //
  // On success:
  // - 'outer_context' will contain output shapes inferred from input shapes.
  Status InferShapesForFunction(
      const FunctionDef* function_def, AttrSlice attributes,
      shape_inference::InferenceContext* outer_context);

  // Performs shape inference for a node inside a function.
  //
  // 'outer_context' is the 'InferenceContext' for the function's call op.
  Status InferShapesForFunctionSubNode(
      const Node* node, shape_inference::InferenceContext* outer_context);

  // Performs validation of 'node' and runs 'node's shape function,
  // storing its shape outputs.
  //
  // All inputs of 'node' must be added to ShapeRefiner prior to
  // adding 'node'.
  //
  // Optionally, if 'node' is in a nested function, the 'InferenceContext' for
  // the call op of the function can be passed as 'outer_context' (pass nullptr
  // otherwise). This gets used to perform constant propagation across Arg
  // nodes by requesting the constant value of the incoming tensor from the
  // 'outer_context'.
  //
  // Returns an error if:
  //  - the shape function for 'node' was not registered.
  //  - 'node' was added before its inputs.
  //  - The shape inference function returns an error.
  Status AddNodeInternal(const Node* node,
                         shape_inference::InferenceContext* outer_context);

  // Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge
  // value can be evaluated, 'evaluated' is set to true and the value returned
  // in 'result'. Otherwise 'evaluated' is set to false.
  //
  // Optionally, if 'node' is in a nested function, the 'InferenceContext' for
  // the call op of the function can be passed as 'outer_context' (pass nullptr
  // otherwise). This gets used to perform constant propagation across Arg
  // nodes by requesting the constant value of the incoming tensor from the
  // 'outer_context'.
  Status EvaluateConstantTensorForEdge(
      const Node* node, int dst_idx, bool* evaluated, Tensor* result,
      shape_inference::InferenceContext* outer_context);

  // Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input
  // tensors. The caller is responsible for checking that the specified edge is
  // scalar and int32 or int64.
  //
  // Optionally, if 'node' is in a nested function, the 'InferenceContext' for
  // the call op of the function can be passed as 'outer_context' (pass nullptr
  // otherwise). This gets used to perform constant propagation across Arg
  // nodes by requesting the constant value of the incoming tensor from the
  // 'outer_context'.
  Status EvaluateConstantIntScalarEdge(
      const Node* node, int dst_idx, bool* evaluated, int64_t* result,
      shape_inference::InferenceContext* outer_context);

  // This function tries to materialize as much information about the 'node's
  // dst_idx input as a statically computable shape; the result may be a
  // partially known shape, depending on what is statically inferable.
  //
  // This is called when node.input[dst_idx] is a tensor that is used to define
  // the shape of some other tensor (e.g., the second argument to Reshape is a
  // <shape> tensor, where each element of the shape tensor is a dimension of
  // the target tensor).  It returns in <result> a shape for that input.
  //
  // Unlike simply resolving node.input[dst_idx] to a constant and then
  // converting that to a shape, this function can return a partial shape. This
  // is useful for cases where the shape tensor is only partially defined, such
  // as with calls for: reshape(x, shape(y)) where shape(y) is partially
  // defined.
  //
  // The implementation has op implementations for ops commonly called on shape
  // tensors, and the implementations are specialized to shape tensors (namely,
  // the output is a vector).
  //
  // <target_context> is used when creating new DimensionHandle and ShapeHandle
  // objects.
  //
  // Optionally, if 'node' is in a nested function, the 'InferenceContext' for
  // the call op of the function can be passed as 'outer_context' (pass nullptr
  // otherwise). This gets used to perform constant propagation across Arg
  // nodes by requesting the constant value of the incoming tensor from the
  // 'outer_context'.
  Status ConstantPartialShape(shape_inference::InferenceContext* target_context,
                              const Node* node, int dst_idx,
                              shape_inference::ShapeHandle* result,
                              shape_inference::InferenceContext* outer_context);
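
  // Example (illustrative, expanding on the comment above): if y has the
  // partially known shape [?, 2, 3], then for the subgraph
  // reshape(x, shape(y)) ConstantPartialShape can return the partial shape
  // [?, 2, 3] for Reshape's second input, even though shape(y) is not a
  // constant tensor at graph-construction time.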

  // Implementation of ConstantPartialShape for StridedSlice nodes.
  //
  // Optionally, if 'node' is in a nested function, the 'InferenceContext' for
  // the call op of the function can be passed as 'outer_context' (pass nullptr
  // otherwise). This gets used to perform constant propagation across Arg
  // nodes by requesting the constant value of the incoming tensor from the
  // 'outer_context'.
  Status PartialStridedSliceShape(
      Node* slice_node, shape_inference::InferenceContext* ctx,
      shape_inference::ShapeHandle* result,
      shape_inference::InferenceContext* outer_context);

  // Runs the shape function registered for the node's op type.
  //
  // Optionally, if 'node' is in a nested function, the 'InferenceContext' for
  // the call op of the function can be passed as 'outer_context' (pass nullptr
  // otherwise). This gets used to perform constant propagation across Arg
  // nodes by requesting the constant value of the incoming tensor from the
  // 'outer_context'.
  Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
                    ExtendedInferenceContext* ec,
                    shape_inference::InferenceContext* outer_context = nullptr);

  int32 graph_def_version_;
  const OpRegistryInterface* const ops_registry_;

  // The lifetime of the cached tensors is bound to the runner, so the runner
  // should be deleted after the tensors.
  GraphRunner graph_runner_;

  // Stores a map from a node to its ExtendedInferenceContext.
  absl::flat_hash_map<const Node*, std::unique_ptr<ExtendedInferenceContext>,
                      hash<const Node*>>
      node_to_context_;

  // Holds a cache from 'tensor name' to the tensor that is
  // evaluatable as a constant expression.  This reduces repeated
  // execution of the entire constant subgraph as a graph is being
  // built up.  This could be changed to some kind of size-based LRU
  // cache to avoid consuming too much memory, if that eventually
  // becomes a concern.
  //
  // Only tensors less than 1KiB are currently stored in the cache.
  static constexpr int64_t kMaxTensorSize = 1024;
  std::unordered_map<string, Tensor> const_tensor_map_;

  bool require_shape_inference_fns_ = true;
  bool disable_constant_propagation_ = false;

  // Function library is optional, but has to be set to enable function
  // shape inference.
  const tensorflow::FunctionLibraryDefinition* function_library_ = nullptr;

  // Cache the graph corresponding to each function definition for which shapes
  // are refined.
  absl::flat_hash_map<const FunctionDef*, std::unique_ptr<const Graph>,
                      hash<const FunctionDef*>>
      functions_;

  TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_