/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/shape_inference.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {
// Converts a shape inference handle to a PartialTensorShape.
Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
                                const shape_inference::ShapeHandle& handle,
                                PartialTensorShape* shape) {
  // The default is already an unknown shape.
  if (!context->RankKnown(handle)) return Status::OK();

  std::vector<int64> dims(context->Rank(handle));
  for (int32 i = 0, end = dims.size(); i < end; ++i) {
    dims[i] = context->Value(context->Dim(handle, i));
  }
  return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
}
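
// Note: for dimensions whose size is unknown, InferenceContext::Value()
// returns InferenceContext::kUnknownDim (-1), which MakePartialShape
// interprets as an unknown dimension, so partially known shapes round-trip
// correctly.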

Status PropagateShapes(Graph* graph,
                       const std::map<int, InferredShape>& arg_shapes,
                       const std::vector<BackEdgeHelper::BackEdge>& back_edges,
                       ShapeRefiner* shape_refiner) {
  std::map<const Node*, const Node*> merge_to_next_iteration;
  for (const auto& e : back_edges) {
    if (e.src->IsNextIteration() && e.dst->IsMerge()) {
      merge_to_next_iteration[e.dst] = e.src;
    }
  }
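
  // The caller (InferShapes) has already removed these back edges from the
  // graph; the map above records which NextIteration node fed each Merge node
  // so the resource Merge special case below can locate it.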

  // Visits the nodes in topological order (reverse post-order), inferring
  // shapes.
  // TODO(phawkins): handle cyclic graphs.
  std::vector<Node*> order;
  GetReversePostOrder(*graph, &order);

  for (Node* n : order) {
    // Ignore the status returned by the shape_refiner. We want the
    // best-effort shapes, even if no shape function is registered for a node.
    Status status = shape_refiner->AddNode(n);
    if (!status.ok()) {
      VLOG(1) << "Shape inference failed for node " << n->name() << ": "
              << status;
    } else {
      shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
      for (int i = 0; i < n->num_outputs(); i++) {
        shape_inference::ShapeHandle handle = context->output(i);
        VLOG(4) << "Output " << i << " for node " << n->name() << ": "
                << context->DebugString(handle);
      }
    }

    if (n->type_string() == "_Arg") {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      auto it = arg_shapes.find(index);
      if (it != arg_shapes.end()) {
        const InferredShape& arg_shape = it->second;
        shape_inference::InferenceContext* context =
            shape_refiner->GetContext(n);

        if (arg_shape.handle_type != DT_INVALID) {
          shape_inference::ShapeHandle handle;
          TF_RETURN_IF_ERROR(context->MakeShapeFromPartialTensorShape(
              arg_shape.handle_shape, &handle));

          // Sets the shape and type of the variable's value.
          context->set_output_handle_shapes_and_types(
              0, std::vector<shape_inference::ShapeAndType>{
                     {handle, arg_shape.handle_type}});
        }

        shape_inference::ShapeHandle handle;
        TF_RETURN_IF_ERROR(
            context->MakeShapeFromPartialTensorShape(arg_shape.shape, &handle));
        TF_RETURN_IF_ERROR(shape_refiner->SetShape(n, 0, handle));
      }
    }

    // Sometimes we have VariableShape nodes in while loops (after Enter
    // nodes). They won't be constant-folded because TensorFlow constant
    // folding does not handle Enter nodes (and thus does not handle any nodes
    // after Enter nodes). We try to replace such VariableShape nodes with
    // Const nodes here.
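    // As a sketch of the rewrite below: for a resource variable whose shape
    // is known to be, say, [2, 3], the subgraph
    //   var -> VariableShape -> consumers
    // becomes
    //   var -ctrl-> Const(value=[2, 3]) -> consumers
    // leaving the VariableShape node in place but without data outputs.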
    if (n->type_string() == "VariableShape") {
      shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
      auto handle_shapes_and_types = context->input_handle_shapes_and_types(0);
      if (handle_shapes_and_types && !handle_shapes_and_types->empty()) {
        shape_inference::ShapeHandle handle =
            handle_shapes_and_types->at(0).shape;
        TensorShapeProto shape_proto;
        context->ShapeHandleToProto(handle, &shape_proto);
        if (!shape_proto.unknown_rank()) {
          NodeDef const_def;
          const_def.set_op("Const");
          Node* var_node;
          TF_RETURN_IF_ERROR(n->input_node(0, &var_node));
          const_def.set_name(
              graph->NewName(absl::StrCat("var_shape_", var_node->name())));
          DataType dtype = n->output_type(0);
          AddNodeAttr("dtype", dtype, &const_def);
          TensorProto value;
          value.set_dtype(dtype);
          value.mutable_tensor_shape()->add_dim()->set_size(
              shape_proto.dim_size());
          for (const auto& dim : shape_proto.dim()) {
            if (dtype == DT_INT32) {
              value.add_int_val(dim.size());
            } else {
              value.add_int64_val(dim.size());
            }
          }
          AddNodeAttr("value", value, &const_def);
          // Copy internal attributes (those prefixed with '_') over to the
          // new Const node.
          for (auto const& attr : n->attrs()) {
            if (*attr.first.begin() == '_') {
              AddNodeAttr(attr.first, attr.second, &const_def);
            }
          }

          Status s;
          Node* const_node = graph->AddNode(const_def, &s);
          TF_RETURN_IF_ERROR(s);

          graph->AddControlEdge(var_node, const_node);
          std::vector<const Edge*> out_edges(n->out_edges().begin(),
                                             n->out_edges().end());
          for (const Edge* e : out_edges) {
            if (e->IsControlEdge()) {
              graph->AddControlEdge(const_node, e->dst());
              graph->RemoveEdge(e);
            } else {
              Node* dst = e->dst();
              int dst_input = e->dst_input();
              graph->RemoveEdge(e);
              graph->AddEdge(const_node, 0, dst, dst_input);
            }
          }
        }
      }
    }

    // Merge nodes close loops, so we removed the NextIteration->Merge edges
    // before performing shape inference. But removing those edges also
    // prevents us from inferring the output shape of a Merge node (we need
    // shapes for all of its inputs).
    // For a loop-invariant resource input's Merge node, we set the output
    // resource shape to the Enter node's resource shape.
    // TODO(b/129367850): clean this up.
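    // Schematically, the loop-invariant pattern detected below is:
    //   Enter -> Merge -> Switch -> [Identity]* -> NextIteration -> Merge
    // i.e. the removed NextIteration back edge traces back, through Identity
    // nodes only, to a Switch node fed directly by this same Merge node.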
    if (n->IsMerge() && n->output_type(0) == DT_RESOURCE) {
      // Check if this is a loop-invariant input's Merge node. We do this by
      // checking whether the corresponding NextIteration node comes directly
      // from a Switch node.
      auto iter = merge_to_next_iteration.find(n);
      if (iter != merge_to_next_iteration.end()) {
        const Node *next_iter = iter->second, *node = next_iter;
        do {
          TF_RETURN_IF_ERROR(node->input_node(0, &node));
        } while (node->IsIdentity());
        const Node* switch_input;
        bool is_loop_invariant = node->IsSwitch() &&
                                 node->input_node(0, &switch_input).ok() &&
                                 switch_input == n;
        if (is_loop_invariant) {
          shape_inference::InferenceContext* context =
              shape_refiner->GetContext(n);
          for (int i = 0; i < n->num_inputs(); i++) {
            const Node* input_node;
            if (n->input_node(i, &input_node).ok()) {
              auto shapes_and_types = context->input_handle_shapes_and_types(i);
              if (shapes_and_types) {
                context->set_output_handle_shapes_and_types(0,
                                                            *shapes_and_types);
              }
              break;
            }
          }
        }
      }
    }
  }
  return Status::OK();
}

// Stores the shapes of the output tensors in a map keyed by node name.
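// Note: GraphShapeInfo (declared in shape_inference.h) maps each node name to
// one InferredShape per output, so shape_info["foo"][1] holds the inferred
// shape of output 1 of node "foo".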
Status StoreOutputShapes(const Graph& graph, const ShapeRefiner& shape_refiner,
                         GraphShapeInfo* shape_info) {
  for (const Node* node : graph.nodes()) {
    shape_inference::InferenceContext* context = shape_refiner.GetContext(node);
    if (!context) continue;

    auto& outputs = (*shape_info)[node->name()];
    outputs.resize(context->num_outputs());
    for (int i = 0; i < context->num_outputs(); ++i) {
      auto& output = outputs[i];
      TF_RETURN_IF_ERROR(
          ShapeHandleToTensorShape(context, context->output(i), &output.shape));

      const auto* handle_shapes_and_types =
          context->output_handle_shapes_and_types(i);
      if (handle_shapes_and_types != nullptr) {
        if (handle_shapes_and_types->size() == 1) {
          TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(
              context, (*handle_shapes_and_types)[0].shape,
              &output.handle_shape));
          output.handle_type = (*handle_shapes_and_types)[0].dtype;
        } else {
          // Otherwise, it may be a resource like a Queue, which can have
          // multiple shapes and types represented by a single handle.
        }
      }
      VLOG(4) << node->name() << " output " << i << " shape "
              << output.shape.DebugString() << " handle_type "
              << DataTypeString(output.handle_type) << " handle_shape "
              << output.handle_shape.DebugString();
    }
  }
  return Status::OK();
}

}  // namespace

Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
                   const tensorflow::FunctionLibraryDefinition* fnlib_def,
                   GraphShapeInfo* shape_info) {
  ShapeRefiner shape_refiner(graph->versions(), graph->op_registry());
  shape_refiner.set_require_shape_inference_fns(false);
  // TODO(dlibenzi): Verify whether it is worth trying to infer shapes within
  // functions. Some functions can be called at multiple locations with
  // different shapes, which will trigger a shape inference based on the
  // arguments passed at the first call.
  // shape_refiner.set_function_library_for_shape_inference(fnlib_def);

  // ShapeRefiner requires that all inputs of a node are present when
  // ShapeRefiner::AddNode is called. To get at least some shape information in
  // loops, we temporarily remove loop backedges and add them back again after
  // the shape inference is complete.
  BackEdgeHelper back_edge;
  TF_RETURN_IF_ERROR(back_edge.Remove(graph));
  TF_RETURN_IF_ERROR(PropagateShapes(graph, arg_shapes,
                                     back_edge.RemovedEdges(), &shape_refiner));
  TF_RETURN_IF_ERROR(back_edge.Replace());

  // Currently information does not flow "backward" from consumers to producers
  // in the shape inference, but we consume the shapes in a second pass in case
  // backward information flow is added in the future.
  return StoreOutputShapes(*graph, shape_refiner, shape_info);
}
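
// A sketch of typical usage (the graph and argument shapes here are
// hypothetical):
//   std::map<int, InferredShape> arg_shapes;
//   arg_shapes[0].shape = PartialTensorShape({-1, 28, 28});  // _Arg index 0.
//   GraphShapeInfo shape_info;
//   TF_RETURN_IF_ERROR(
//       InferShapes(graph, arg_shapes, /*fnlib_def=*/nullptr, &shape_info));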

xla::StatusOr<InferredShape> MergeInferredShapes(const InferredShape& a,
                                                 const InferredShape& b) {
  InferredShape result;
  TF_RETURN_IF_ERROR(a.shape.MergeWith(b.shape, &result.shape));

  if (a.handle_type == DT_INVALID) {
    result.handle_type = b.handle_type;
  } else if (b.handle_type == DT_INVALID) {
    result.handle_type = a.handle_type;
  } else if (a.handle_type == b.handle_type) {
    result.handle_type = a.handle_type;
  } else {
    return errors::InvalidArgument(
        "Mismatched resource types: ", DataTypeString(a.handle_type), " vs. ",
        DataTypeString(b.handle_type));
  }
  TF_RETURN_IF_ERROR(
      a.handle_shape.MergeWith(b.handle_shape, &result.handle_shape));
  return result;
}
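
// For example, merging a shape known to be [?, 10] with one known to be
// [5, ?] yields [5, 10], while merging handle types DT_FLOAT and DT_INT32
// fails with an InvalidArgument error.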

}  // namespace tensorflow