• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
17 
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
26 
27 namespace tensorflow {
// Forward declaration only: ShapeRefiner (below) names GraphProperties in a
// friend declaration, which requires the class name to be declared here.
namespace grappler {
class GraphProperties;
}  // namespace grappler
31 
32 // This class stores extra inference information in addition to
33 // InferenceContext, such as node input and output types.
34 class ExtendedInferenceContext {
35  public:
ExtendedInferenceContext(std::unique_ptr<shape_inference::InferenceContext> ic,const Node * node)36   ExtendedInferenceContext(
37       std::unique_ptr<shape_inference::InferenceContext> ic, const Node* node)
38       : inference_context_(std::move(ic)), op_(node->name()) {
39     input_types_.reserve(node->num_inputs());
40     for (int i = 0; i < node->num_inputs(); i++) {
41       input_types_.push_back(node->input_type(i));
42     }
43     output_types_.reserve(node->num_outputs());
44     for (int i = 0; i < node->num_outputs(); i++) {
45       output_types_.push_back(node->output_type(i));
46     }
47   }
48 
input_type(int64 idx)49   DataType input_type(int64 idx) const { return input_types_[idx]; }
output_type(int64 idx)50   DataType output_type(int64 idx) const { return output_types_[idx]; }
51 
get_context()52   shape_inference::InferenceContext* get_context() {
53     return inference_context_.get();
54   }
55 
op()56   std::string op() const { return op_; }
57 
58  private:
59   std::unique_ptr<shape_inference::InferenceContext> inference_context_;
60   std::string op_;
61   std::vector<DataType> input_types_;
62   std::vector<DataType> output_types_;
63 
64   TF_DISALLOW_COPY_AND_ASSIGN(ExtendedInferenceContext);
65 };
66 
67 // ShapeRefiner performs shape inference for TensorFlow Graphs.  It is
68 // responsible for instantiating InferenceContext objects for each
69 // Node in the Graph, and providing/storing the 'input_tensor' Tensors
70 // used by Shape Inference functions, when available at graph
71 // construction time.
72 class ShapeRefiner {
73  public:
74   ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops);
75 
76   // Same as ShapeRefiner(versions.producer(), ops)
77   ShapeRefiner(const VersionDef& versions, const OpRegistryInterface* ops);
78 
79   ~ShapeRefiner();
80 
81   // Performs validation of 'node' and runs 'node's shape function,
82   // storing its shape outputs.
83   //
84   // All inputs of 'node' must be added to ShapeRefiner prior to
85   // adding 'node'.
86   //
87   // Returns an error if:
88   //  - the shape function for 'node' was not registered.
89   //  - 'node' was added before its inputs.
90   //  - The shape inference function returns an error.
91   Status AddNode(const Node* node);
92 
93   // Sets 'node's 'output_port' output to have shape 'shape'.
94   //
95   // Returns an error if 'node' was not previously added to this
96   // object, if 'output_port' is invalid, or if 'shape' is
97   // not compatible with the existing shape of the output.
98   Status SetShape(const Node* node, int output_port,
99                   shape_inference::ShapeHandle shape);
100 
101   // Update the input shapes of node in case the shapes of the fan-ins of 'node'
102   // have themselves been modified (For example, in case of incremental shape
103   // refinement). If 'relax' is true, a new shape with the broadest set of
104   // information will be set as the new input (see InferenceContext::RelaxInput
105   // for full details and examples). Sets refined to true if any shapes have
106   // changed (in their string representations). Note that shapes may have been
107   // updated to newer versions (but with identical string representations) even
108   // if <*refined> is set to false.
109   Status UpdateNode(const Node* node, bool relax, bool* refined);
110 
111   // Returns the InferenceContext for 'node', if present.
GetContext(const Node * node)112   shape_inference::InferenceContext* GetContext(const Node* node) const {
113     auto it = node_to_context_.find(node);
114     if (it == node_to_context_.end()) {
115       return nullptr;
116     }
117     return it->second->get_context();
118   }
119 
120   // Returns the ExtendedInferenceContext for 'node', if present.
GetExtendedContext(const Node * node)121   ExtendedInferenceContext* GetExtendedContext(const Node* node) const {
122     auto it = node_to_context_.find(node);
123     if (it == node_to_context_.end()) {
124       return nullptr;
125     }
126     return it->second.get();
127   }
128 
129   // Getters and setters for graph_def_version_.
graph_def_version()130   int32 graph_def_version() const { return graph_def_version_; }
set_graph_def_version(int32 version)131   void set_graph_def_version(int32 version) { graph_def_version_ = version; }
132 
set_require_shape_inference_fns(bool require_shape_inference_fns)133   void set_require_shape_inference_fns(bool require_shape_inference_fns) {
134     require_shape_inference_fns_ = require_shape_inference_fns;
135   }
set_disable_constant_propagation(bool disable)136   void set_disable_constant_propagation(bool disable) {
137     disable_constant_propagation_ = disable;
138   }
139 
140   // Set function library to enable function shape inference.
141   // Without function library, function inference always yields unknown shapes.
142   // With this enabled, shape inference can take more time since it descends
143   // into all function calls. It doesn't do inference once for each function
144   // definition, but once for each function call.
145   // The function library must outlive the shape refiner.
set_function_library_for_shape_inference(const tensorflow::FunctionLibraryDefinition * lib)146   void set_function_library_for_shape_inference(
147       const tensorflow::FunctionLibraryDefinition* lib) {
148     function_library_ = lib;
149   }
150 
function_shape_inference_supported()151   bool function_shape_inference_supported() const {
152     return function_library_ != nullptr;
153   }
154 
155  private:
156   friend class ShapeRefinerTest;
157   friend class ::tensorflow::grappler::GraphProperties;
158 
159   // Returns true if the ranks and all dimensions of <s0> and <s1> are either
160   // equal in value or both unknown.
161   static bool SameDefinedShape(shape_inference::InferenceContext* c,
162                                shape_inference::ShapeHandle s0,
163                                shape_inference::ShapeHandle s1);
164 
165   // Returns true if the shapes and types stored in <*existing> are identical in
166   // value to the shapes and types in <*updated>.
167   static bool IsUpdatedShapesOrTypes(
168       shape_inference::InferenceContext* c,
169       const std::vector<shape_inference::ShapeAndType>& existing,
170       const std::vector<shape_inference::ShapeAndType>& updated);
171 
172   // Performs shape inference for the given function_def within the
173   // given outer_context. Internally it instantiates the function as a graph
174   // and runs shape inference recursively on it with the input shapes provided
175   // by the outer_context.
176   //
177   // Returns an error if:
178   // - number of inputs/outputs on outer_context doesn't match the function_def
179   //
180   // On success:
181   // - outer_context will contain output shapes inferred from input shapes
182   Status InferShapesForFunction(const FunctionDef* function_def,
183                                 AttrSlice attributes,
184                                 ExtendedInferenceContext* outer_context);
185 
186   // Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge
187   // value can be evaluated, 'evaluated' is set to true and the value returned
188   // in 'result'. Otherwise 'evaluated' is set to false.
189   Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx,
190                                        bool* evaluated, Tensor* result);
191 
192   // Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input
193   // tensors. The caller is responsible for checking that the specified edge is
194   // scalar and int32 or int64.
195   Status EvaluateConstantIntScalarEdge(const Node* node, int dst_idx,
196                                        bool* evaluated, int64* result);
197 
198   // This function tries to materialize as much information about the 'node''s
199   // dst_idx input as a statically computable shape, and the result may be
200   // partially known, depending on what is statically inferable.
201   //
202   // This is called when node.input[dst_idx] is a tensor that is used to define
203   // the shape of some other tensor (e.g., the second argument to Reshape is a
204   // <shape> tensor, where each element of the shape tensor is a dimension of
205   // the target tensor).  It returns in <result> a shape for that input.
206   //
207   // Unlike simply resolving node.input[dst_idx] to a constant and then
208   // converting that to a shape, this function can return a partial shape. This
209   // is useful for cases where the shape tensor is only partially defined, such
210   // as with calls for: reshape(x, shape(y)) where shape(y) is partially
211   // defined.
212   //
213   // The implementation has op implementations for ops commonly called on shape
214   // tensors, and the implementations are specialized to shape tensors (namely,
215   // the output is a vector).
216   //
217   // <target_context> is used when creating new DimensionHandle and ShapeHandle
218   // objects.
219   Status ConstantPartialShape(shape_inference::InferenceContext* target_context,
220                               const Node* node, int dst_idx,
221                               shape_inference::ShapeHandle* result);
222 
223   // Implementation of ConstantPartialShape for StridedSlice nodes.
224   Status PartialStridedSliceShape(Node* slice_node,
225                                   shape_inference::InferenceContext* ctx,
226                                   shape_inference::ShapeHandle* result);
227 
228   Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
229                     ExtendedInferenceContext* ec);
230 
231   int32 graph_def_version_;
232   const OpRegistryInterface* const ops_registry_;
233 
234   // The lifetime of the tensors are bound to the runner, so it should be the
235   // deleted after the tensors.
236   GraphRunner graph_runner_;
237 
238   // Stores a map from a node to its ExtendedInferenceContext.
239   std::unordered_map<const Node*, std::unique_ptr<ExtendedInferenceContext>>
240       node_to_context_;
241 
242   // Holds a cache from 'tensor name' to the tensor that is
243   // evaluatable as a constant expression.  This reduces repeated
244   // execution of the entire constant subgraph as a graph is being
245   // built up.  This could be changed to some kind of size-based LRU
246   // cache to avoid consuming too much memory, if that eventually
247   // becomes a concern.
248   //
249   // Only tensors less than 1KiB are currently stored in the cache.
250   static constexpr int64 kMaxTensorSize = 1024;
251   std::unordered_map<string, Tensor> const_tensor_map_;
252 
253   bool require_shape_inference_fns_ = true;
254   bool disable_constant_propagation_ = false;
255 
256   // Function library is optional, but has to be set to enable function
257   // shape inference.
258   const tensorflow::FunctionLibraryDefinition* function_library_ = nullptr;
259 
260   // Cache the graph corresponding to each functin definition for which shapes
261   // are refined.
262   std::unordered_map<const FunctionDef*, std::unique_ptr<const Graph>>
263       functions_;
264 
265   TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner);
266 };
267 
268 }  // namespace tensorflow
269 
270 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
271