/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"

#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace grappler {

// This optimizer first rewrites Prod(Shape(x)) into Size(x). It then uses
// symbolic shapes to simplify Div(Size(x), Size(y)) in the case where x and y
// have dimensions that are unknown but known to be identical: e.g., we can
// deduce that Div(Size([2,?,2]), Size([1,?,2])) is 2 if the two unknown
// dimensions are known to be identical. This can be inferred when they share
// the same symbolic representation (a negative integer dimension size).
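//
// For example, for a tensor x of shape [2, 3, 4], Prod(Shape(x)) computes
// 2 * 3 * 4 = 24, which is exactly the element count returned by Size(x).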
Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                GraphDef* optimized_graph) {
  // Do a quick check to determine if we can skip this optimizer.
  bool can_optimize = false;
  bool has_div = false;
  bool has_size = false;
  bool has_shape = false;
  bool has_prod = false;
  auto is_int = [](const NodeDef& node) -> bool {
    return node.attr().at("T").type() == DT_INT32 ||
           node.attr().at("T").type() == DT_INT64;
  };
  for (const NodeDef& node : item.graph.node()) {
    if (IsShape(node)) {
      has_shape = true;
    } else if (IsProd(node) && is_int(node)) {
      has_prod = true;
    } else if (IsDiv(node) && is_int(node)) {
      has_div = true;
    } else if (IsSize(node)) {
      has_size = true;
    }
    if ((has_shape && has_prod) || (has_div && has_size)) {
      can_optimize = true;
      break;
    }
  }
  if (!can_optimize) {
    return errors::Aborted("Nothing to do.");
  }

  *optimized_graph = item.graph;
  GraphProperties properties(item);
  bool inferred_properties = false;
  {
    MutableGraphView graph(optimized_graph);
    // The product of all the dimensions in a tensor shape can be expressed
    // more simply as the size of the tensor.
    for (auto& node : *optimized_graph->mutable_node()) {
      if (!IsShape(node)) {
        continue;
      }
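      // Examine every consumer of this Shape node's output.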
      for (MutableGraphView::InputPort fanout :
           graph.GetFanout(MutableGraphView::OutputPort(&node, 0))) {
        if (fanout.node->op() != "Prod") {
          continue;
        }
        if (fanout.node->attr().count("keep_dims") != 0 &&
            fanout.node->attr().at("keep_dims").b()) {
          // Keeping the reduced dimensions won't result in a scalar, so we
          // can't rewrite the whole expression directly as a Size operation.
          continue;
        }
        const MutableGraphView::OutputPort reduce_indices =
            graph.GetRegularFanin(MutableGraphView::InputPort(fanout.node, 1));
        if (!inferred_properties) {
          // Infer properties lazily in case they are not needed.
          TF_RETURN_IF_ERROR(
              properties.InferStatically(/*assume_valid_feeds=*/false,
                                         /*aggressive_shape_inference=*/false,
                                         /*include_tensor_values=*/false));
          inferred_properties = true;
        }
        const auto& prop =
            properties.GetOutputProperties(reduce_indices.node->name());
        const int prop_size = prop.size();
        if (prop_size <= reduce_indices.port_id) {
          continue;
        }
        const TensorShapeProto& reduction_indices_shape =
            prop[reduce_indices.port_id].shape();
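        // The Shape output is rank 1, so a reduction indices tensor with a
        // single coefficient means the Prod reduces over that one axis,
        // i.e. over every dimension of the original input.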
        if (NumCoefficients(reduction_indices_shape) == 1) {
          const auto& input_props = properties.GetInputProperties(node.name());
          if (input_props.size() != 1) {
            continue;
          }
          // Rewrite the reduction of the shape dimensions as a Size operation.
          NodeDef size_node(*fanout.node);
          const DataType type = input_props[0].dtype();
          size_node.set_op("Size");
          size_node.set_input(0, node.input(0));
          size_node.set_input(1, AsControlDependency(node));
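          // Size has no reduction attrs; it carries the input tensor's dtype
          // in "T" and returns the element count in the Prod's integer type,
          // copied into "out_type".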
          size_node.mutable_attr()->erase("Tidx");
          size_node.mutable_attr()->erase("keep_dims");
          (*size_node.mutable_attr())["out_type"] = fanout.node->attr().at("T");
          (*size_node.mutable_attr())["T"].set_type(type);

          // The corresponding Size kernel might not exist on the device where
          // Prod was placed, so assign the Size kernel to the same device as
          // the input.
          size_node.set_device(node.device());

          // In the unlikely event that "Size" is not registered on the input
          // device, skip the optimization.
          Status s = IsKernelRegisteredForNode(size_node);
          if (!s.ok()) {
            continue;
          }

          fanout.node->Swap(&size_node);
        }
      }
    }
  }
  {
    MutableGraphView graph(optimized_graph);
    for (auto& node : *optimized_graph->mutable_node()) {
      // Try to convert the ratio of 2 symbolic tensor sizes into a constant.
      // This is possible whenever the symbolic dimensions in the numerator and
      // denominator cancel each other.
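      // For example, Size([2,?,4]) / Size([?,2,2]) is 8d / 4d = 2 when the
      // two unknown dimensions share the same symbolic value d.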
      if (node.op() == "Div") {
        const MutableGraphView::OutputPort input1 =
            graph.GetRegularFanin(MutableGraphView::InputPort(&node, 0));
        const MutableGraphView::OutputPort input2 =
            graph.GetRegularFanin(MutableGraphView::InputPort(&node, 1));
        if (input1.node == nullptr || input2.node == nullptr) continue;
        if (!IsSize(*input1.node) || !IsSize(*input2.node)) {
          continue;
        }
        if (!inferred_properties) {
          // Infer properties lazily in case they are not needed.
          TF_RETURN_IF_ERROR(
              properties.InferStatically(/*assume_valid_feeds=*/false,
                                         /*aggressive_shape_inference=*/false,
                                         /*include_tensor_values=*/false));
          inferred_properties = true;
        }
        const auto& prop1 = properties.GetInputProperties(input1.node->name());
        const auto& prop2 = properties.GetInputProperties(input2.node->name());
        if (prop1.size() != 1 || prop2.size() != 1) {
          continue;
        }
        const TensorShapeProto& shape1 = prop1[0].shape();
        const TensorShapeProto& shape2 = prop2[0].shape();
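        // ComputeSizeRatio returns a negative value when the ratio cannot be
        // determined statically, so only rewrite a known constant result.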
        int64_t result = ComputeSizeRatio(shape1, shape2);
        if (result >= 0) {
          // Replace div with constant.
          node.set_op("Const");
          DataType dtype = node.attr().at("T").type();
          node.mutable_attr()->erase("T");
          (*node.mutable_attr())["dtype"].set_type(dtype);
          TensorProto* t = (*node.mutable_attr())["value"].mutable_tensor();
          t->set_dtype(dtype);
          *t->mutable_tensor_shape() = TensorShapeProto();
          if (dtype == DT_INT32) {
            t->add_int_val(result);
          } else {
            t->add_int64_val(result);
          }
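          // Keep the original inputs as control dependencies so the graph's
          // execution ordering constraints are preserved.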
          node.set_input(0, AsControlDependency(node.input(0)));
          node.set_input(1, AsControlDependency(node.input(1)));
        }
      }
    }
  }
  return OkStatus();
}
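
// A minimal usage sketch, assuming a GrapplerItem `item` built elsewhere; the
// cluster argument is unused by this optimizer, so nullptr suffices:
//
//   ShapeOptimizer optimizer;
//   GraphDef optimized;
//   Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &optimized);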

}  // namespace grappler
}  // namespace tensorflow