/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/shape_inference.h"

#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

// Converts a shape inference handle to a PartialTensorShape.
Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
                                const shape_inference::ShapeHandle& handle,
                                PartialTensorShape* shape) {
  // The default is already unknown.
  if (!context->RankKnown(handle)) return Status::OK();

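  // Copy each dimension; an unknown dimension comes back from Value() as -1
  // (InferenceContext::kUnknownDim), which MakePartialShape in turn treats as
  // an unknown extent.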
  std::vector<int64> dims(context->Rank(handle));
  for (int32 i = 0, end = dims.size(); i < end; ++i) {
    dims[i] = context->Value(context->Dim(handle, i));
  }
  return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
}

Status PropagateShapes(Graph* graph,
                       const std::map<int, InferredShape>& arg_shapes,
                       const std::vector<BackEdgeHelper::BackEdge>& back_edges,
                       ShapeRefiner* shape_refiner) {
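  // Map each Merge node to the NextIteration node that fed its (now removed)
  // back edge, so we can later recognize loop-invariant resource values.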
  std::map<const Node*, const Node*> merge_to_next_iteration;
  for (const auto& e : back_edges) {
    if (e.src->IsNextIteration() && e.dst->IsMerge()) {
      merge_to_next_iteration[e.dst] = e.src;
    }
  }

  // Visits the nodes in topological order (reverse post-order), inferring
  // shapes.
  // TODO(phawkins): handle cyclic graphs.
  std::vector<Node*> order;
  GetReversePostOrder(*graph, &order);

  for (Node* n : order) {
    // Ignore the status returned by the shape_refiner. We want the best effort
    // shapes, even if no shape function is registered for a node.
    Status status = shape_refiner->AddNode(n);
    if (!status.ok()) {
      VLOG(1) << "Shape inference failed for node " << n->name() << ": "
              << status;
    } else {
      shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
      for (int i = 0; i < n->num_outputs(); i++) {
        shape_inference::ShapeHandle handle = context->output(i);
        VLOG(4) << "Output " << i << " for node " << n->name() << ": "
                << context->DebugString(handle);
      }
    }

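    // Seed _Arg nodes with the shapes supplied by the caller, including the
    // shape and type of the value held by a resource argument, if known.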
    if (n->type_string() == "_Arg") {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      auto it = arg_shapes.find(index);
      if (it != arg_shapes.end()) {
        const InferredShape& arg_shape = it->second;
        shape_inference::InferenceContext* context =
            shape_refiner->GetContext(n);

        if (arg_shape.handle_type != DT_INVALID) {
          shape_inference::ShapeHandle handle;
          TF_RETURN_IF_ERROR(context->MakeShapeFromPartialTensorShape(
              arg_shape.handle_shape, &handle));

          // Sets the shape and type of the variable's value.
          context->set_output_handle_shapes_and_types(
              0, std::vector<shape_inference::ShapeAndType>{
                     {handle, arg_shape.handle_type}});
        }

        shape_inference::ShapeHandle handle;
        TF_RETURN_IF_ERROR(
            context->MakeShapeFromPartialTensorShape(arg_shape.shape, &handle));
        TF_RETURN_IF_ERROR(shape_refiner->SetShape(n, 0, handle));
      }
    }

    // Sometimes we have VariableShape nodes in a while loop (after Enter
    // nodes). They won't be constant-folded because TensorFlow constant
    // folding does not handle Enter nodes (and thus does not handle any nodes
    // after Enter nodes). We try to replace such VariableShape nodes with
    // Const nodes here.
    if (n->type_string() == "VariableShape") {
      shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
      auto handle_shapes_and_types = context->input_handle_shapes_and_types(0);
      if (handle_shapes_and_types && !handle_shapes_and_types->empty()) {
        shape_inference::ShapeHandle handle =
            handle_shapes_and_types->at(0).shape;
        TensorShapeProto shape_proto;
        context->ShapeHandleToProto(handle, &shape_proto);
        if (!shape_proto.unknown_rank()) {
          NodeDef const_def;
          const_def.set_op("Const");
          Node* var_node;
          TF_RETURN_IF_ERROR(n->input_node(0, &var_node));
          const_def.set_name(
              graph->NewName(absl::StrCat("var_shape_", var_node->name())));
          DataType dtype = n->output_type(0);
          AddNodeAttr("dtype", dtype, &const_def);
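          // Build the constant's value: a 1-D tensor of length rank whose
          // i-th entry is the size of dimension i, matching what VariableShape
          // itself would produce.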
          TensorProto value;
          value.set_dtype(dtype);
          value.mutable_tensor_shape()->add_dim()->set_size(
              shape_proto.dim_size());
          for (const auto& dim : shape_proto.dim()) {
            if (dtype == DT_INT32) {
              value.add_int_val(dim.size());
            } else {
              value.add_int64_val(dim.size());
            }
          }
          AddNodeAttr("value", value, &const_def);
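          // Carry over framework-internal attributes (those whose names start
          // with '_'), e.g. colocation constraints, to the replacement node.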
          for (auto const& attr : n->attrs()) {
            if (*attr.first.begin() == '_') {
              AddNodeAttr(attr.first, attr.second, &const_def);
            }
          }

          Status s;
          Node* const_node = graph->AddNode(const_def, &s);
          TF_RETURN_IF_ERROR(s);

          graph->AddControlEdge(var_node, const_node);
          std::vector<const Edge*> out_edges(n->out_edges().begin(),
                                             n->out_edges().end());
          for (const Edge* e : out_edges) {
            if (e->IsControlEdge()) {
              graph->AddControlEdge(const_node, e->dst());
              graph->RemoveEdge(e);
            } else {
              Node* dst = e->dst();
              int dst_input = e->dst_input();
              graph->RemoveEdge(e);
              graph->AddEdge(const_node, 0, dst, dst_input);
            }
          }
        }
      }
    }

    // A Merge node creates a cycle, so we removed the NextIteration->Merge
    // edges before performing shape inference. But removing those edges also
    // prevents us from inferring the output shape of a Merge node, since that
    // requires shapes for all of its inputs.
    // For the Merge node of a loop-invariant resource input, we set the
    // output resource shape to the Enter node's resource shape.
    // TODO(b/129367850): clean this up.
    if (n->IsMerge() && n->output_type(0) == DT_RESOURCE) {
      // Check whether this is the Merge node of a loop-invariant input. We do
      // so by checking whether the corresponding NextIteration node is fed by
      // a Switch node (possibly through a chain of Identity nodes).
      auto iter = merge_to_next_iteration.find(n);
      if (iter != merge_to_next_iteration.end()) {
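        // Walk backwards from the NextIteration node, skipping Identity nodes.
        // If we land on a Switch whose data input is this very Merge node, the
        // value flows around the loop unchanged, i.e. it is loop invariant.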
        const Node *next_iter = iter->second, *node = next_iter;
        do {
          TF_RETURN_IF_ERROR(node->input_node(0, &node));
        } while (node->IsIdentity());
        const Node* switch_input;
        bool is_loop_invariant = node->IsSwitch() &&
                                 node->input_node(0, &switch_input).ok() &&
                                 switch_input == n;
        if (is_loop_invariant) {
          shape_inference::InferenceContext* context =
              shape_refiner->GetContext(n);
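          // Copy the resource shape/type from the first input whose node
          // resolves; with the back edge removed, that is the Enter node's
          // edge.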
          for (int i = 0; i < n->num_inputs(); i++) {
            const Node* input_node;
            if (n->input_node(i, &input_node).ok()) {
              auto shapes_and_types = context->input_handle_shapes_and_types(i);
              if (shapes_and_types) {
                context->set_output_handle_shapes_and_types(0,
                                                            *shapes_and_types);
              }
              break;
            }
          }
        }
      }
    }
  }
  return Status::OK();
}

// Stores the shapes of the output tensors of every node in `graph` in
// `shape_info`, keyed by node name.
Status StoreOutputShapes(const Graph& graph, const ShapeRefiner& shape_refiner,
                         GraphShapeInfo* shape_info) {
  for (const Node* node : graph.nodes()) {
    shape_inference::InferenceContext* context = shape_refiner.GetContext(node);
    if (!context) continue;

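    // One InferredShape per output of the node, recording both the tensor's
    // shape and, for resources, the shape/type of the value it refers to.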
    auto& outputs = (*shape_info)[node->name()];
    outputs.resize(context->num_outputs());
    for (int i = 0; i < context->num_outputs(); ++i) {
      auto& output = outputs[i];
      TF_RETURN_IF_ERROR(
          ShapeHandleToTensorShape(context, context->output(i), &output.shape));

      const auto* handle_shapes_and_types =
          context->output_handle_shapes_and_types(i);
      if (handle_shapes_and_types != nullptr) {
        if (handle_shapes_and_types->size() == 1) {
          TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(
              context, (*handle_shapes_and_types)[0].shape,
              &output.handle_shape));
          output.handle_type = (*handle_shapes_and_types)[0].dtype;
        } else {
          // Otherwise, the handle may refer to a resource such as a Queue,
          // which can have multiple shapes and types represented by a single
          // handle.
        }
      }
      VLOG(4) << node->name() << " output " << i << " shape "
              << output.shape.DebugString() << " handle_type "
              << DataTypeString(output.handle_type) << " handle_shape "
              << output.handle_shape.DebugString();
    }
  }
  return Status::OK();
}

}  // namespace

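// Example usage (a minimal sketch, not part of this file; `graph` and
// `flib_def` are assumed to be set up by the caller, and the enclosing
// function is assumed to return a Status):
//
//   GraphShapeInfo shape_info;
//   TF_RETURN_IF_ERROR(
//       InferShapes(graph, /*arg_shapes=*/{}, flib_def, &shape_info));
//   // shape_info now maps each node name to the inferred shapes of its
//   // outputs.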
Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
                   const tensorflow::FunctionLibraryDefinition* fnlib_def,
                   GraphShapeInfo* shape_info) {
  ShapeRefiner shape_refiner(graph->versions(), graph->op_registry());
  shape_refiner.set_require_shape_inference_fns(false);
  // TODO(dlibenzi): Verify whether it is worth trying to infer shapes within
  // functions. Some functions can be called at multiple locations with
  // different shapes, which would trigger a shape inference based on the
  // arguments passed at the first call.
  // shape_refiner.set_function_library_for_shape_inference(fnlib_def);

  // ShapeRefiner requires that all inputs of a node are present when
  // ShapeRefiner::AddNode is called. To get at least some shape information in
  // loops, we temporarily remove loop backedges and add them back again after
  // the shape inference is complete.
  BackEdgeHelper back_edge;
  TF_RETURN_IF_ERROR(back_edge.Remove(graph));
  TF_RETURN_IF_ERROR(PropagateShapes(graph, arg_shapes,
                                     back_edge.RemovedEdges(), &shape_refiner));
  TF_RETURN_IF_ERROR(back_edge.Replace());

  // Currently information does not flow "backward" from consumers to producers
  // in the shape inference, but we consume the shapes in a second pass in case
  // backward information flow is added in the future.
  return StoreOutputShapes(*graph, shape_refiner, shape_info);
}

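// Example of the merge semantics (a sketch with hypothetical shapes):
// merging {shape: [?, 2], handle_type: DT_FLOAT} with
// {shape: [3, ?], handle_type: DT_INVALID} yields
// {shape: [3, 2], handle_type: DT_FLOAT}, whereas two different valid
// handle_types fail with an InvalidArgument error.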
xla::StatusOr<InferredShape> MergeInferredShapes(const InferredShape& a,
                                                 const InferredShape& b) {
  InferredShape result;
  TF_RETURN_IF_ERROR(a.shape.MergeWith(b.shape, &result.shape));

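  // DT_INVALID means "handle type unknown": prefer whichever side knows the
  // type, and report an error only if both sides know and disagree.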
  if (a.handle_type == DT_INVALID) {
    result.handle_type = b.handle_type;
  } else if (b.handle_type == DT_INVALID) {
    result.handle_type = a.handle_type;
  } else if (a.handle_type == b.handle_type) {
    result.handle_type = a.handle_type;
  } else {
    return errors::InvalidArgument(
        "Mismatched resource types: ", DataTypeString(a.handle_type), " vs. ",
        DataTypeString(b.handle_type));
  }
  TF_RETURN_IF_ERROR(
      a.handle_shape.MergeWith(b.handle_shape, &result.handle_shape));
  return result;
}

}  // namespace tensorflow