1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
17
18 #include "tensorflow/core/framework/tensor.pb.h"
19 #include "tensorflow/core/framework/tensor_shape.pb.h"
20 #include "tensorflow/core/framework/types.h"
21 #include "tensorflow/core/grappler/grappler_item.h"
22 #include "tensorflow/core/grappler/mutable_graph_view.h"
23 #include "tensorflow/core/grappler/op_types.h"
24 #include "tensorflow/core/grappler/utils.h"
25 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
26 #include "tensorflow/core/lib/core/errors.h"
27
28 namespace tensorflow {
29 namespace grappler {
30
// This optimizer first rewrites Prod(Shape(x)) into Size(x). It then uses
// symbolic shapes to simplify Div(Size(x), Size(y)) in the case that x and y
// share symbolic shapes that are unknown but known to be identical, e.g. we can
// deduce that Div(Size([2,?,2]), Size([1,?,2])) is 2 if the two unknown
// dimensions are known to be identical. This can be inferred if they share the
// same symbolic representation (negative integer dimension size).
// Runs the two shape-simplification passes over `item.graph` and writes the
// result to `optimized_graph`:
//   Pass 1: rewrites Prod(Shape(x)) into Size(x).
//   Pass 2: folds Div(Size(x), Size(y)) into a scalar Const when the
//           (possibly symbolic) shapes let the ratio be computed statically.
// Returns errors::Aborted("Nothing to do.") when neither pattern can occur,
// signalling the meta-optimizer that this pass made no changes.
// `cluster` is unused by this optimizer.
Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                GraphDef* optimized_graph) {
  // Do a quick check to determine if we can skip this optimizer. This avoids
  // the cost of copying the graph and running static shape inference when no
  // candidate pattern is present.
  bool can_optimize = false;
  bool has_div = false;
  bool has_size = false;
  bool has_shape = false;
  bool has_prod = false;
  // True iff the node's "T" attr is an integer type; shape arithmetic is
  // always integer-typed. Assumes "T" is present, which holds for any
  // well-formed Prod/Div node (attr().at() would fail otherwise).
  auto is_int = [](const NodeDef& node) -> bool {
    return node.attr().at("T").type() == DT_INT32 ||
           node.attr().at("T").type() == DT_INT64;
  };
  for (const NodeDef& node : item.graph.node()) {
    if (IsShape(node)) {
      has_shape = true;
    } else if (IsProd(node) && is_int(node)) {
      has_prod = true;
    } else if (IsDiv(node) && is_int(node)) {
      has_div = true;
    } else if (IsSize(node)) {
      has_size = true;
    }
    // Pass 1 needs Shape+Prod; pass 2 needs Div+Size. Stop scanning as soon
    // as either pair has been seen.
    if ((has_shape && has_prod) || (has_div && has_size)) {
      can_optimize = true;
      break;
    }
  }
  if (!can_optimize) {
    return errors::Aborted("Nothing to do.");
  }

  // Optimize a copy of the input graph in place.
  *optimized_graph = item.graph;
  GraphProperties properties(item);
  bool inferred_properties = false;
  {
    MutableGraphView graph(optimized_graph);
    // The product of all the dimensions in a tensor shape can be expressed more
    // simply as the size of the tensor.
    for (auto& node : *optimized_graph->mutable_node()) {
      if (!IsShape(node)) {
        continue;
      }
      // Examine every consumer of this Shape node's output, looking for Prod.
      for (MutableGraphView::InputPort fanout :
           graph.GetFanout(MutableGraphView::OutputPort(&node, 0))) {
        if (fanout.node->op() != "Prod") {
          continue;
        }
        if (fanout.node->attr().count("keep_dims") != 0 &&
            fanout.node->attr().at("keep_dims").b()) {
          // Keeping the reduced dimensions won't result in a scalar, so we
          // can't rewrite the whole expression directly as a Size operation.
          continue;
        }
        // Input 1 of Prod carries the reduction indices.
        const MutableGraphView::OutputPort reduce_indices =
            graph.GetRegularFanin(MutableGraphView::InputPort(fanout.node, 1));
        if (!inferred_properties) {
          // Infer properties lazily in case they are not needed.
          TF_RETURN_IF_ERROR(
              properties.InferStatically(/*assume_valid_feeds=*/false,
                                         /*aggressive_shape_inference=*/false,
                                         /*include_tensor_values=*/false));
          inferred_properties = true;
        }
        const auto& prop =
            properties.GetOutputProperties(reduce_indices.node->name());
        const int prop_size = prop.size();
        if (prop_size <= reduce_indices.port_id) {
          // No inferred shape available for the reduction indices; bail out.
          continue;
        }
        const TensorShapeProto& reduction_indices_shape =
            prop[reduce_indices.port_id].shape();
        // A reduction-indices tensor with exactly one element means Prod
        // reduces over a single axis. Shape's output is 1-D, so reducing its
        // only axis multiplies every dimension together — i.e. the tensor
        // size.
        if (NumCoefficients(reduction_indices_shape) == 1) {
          const auto& input_props = properties.GetInputProperties(node.name());
          if (input_props.size() != 1) {
            continue;
          }
          // Rewrite the reduction of the shape dimensions as a Size operation.
          NodeDef size_node(*fanout.node);
          const DataType type = input_props[0].dtype();
          size_node.set_op("Size");
          size_node.set_input(0, node.input(0));
          // Keep a control dependency on the Shape node so execution ordering
          // relative to it is preserved.
          size_node.set_input(1, AsControlDependency(node));
          // Drop Prod-only attrs; Size uses "out_type" for its result type.
          size_node.mutable_attr()->erase("Tidx");
          size_node.mutable_attr()->erase("keep_dims");
          (*size_node.mutable_attr())["out_type"] = fanout.node->attr().at("T");
          (*size_node.mutable_attr())["T"].set_type(type);

          // The corresponding Size kernel might not exist on the device where
          // Prod was placed, so assign the Size kernel to the same device as
          // the input.
          size_node.set_device(node.device());

          // In the unlikely event that "Size" is not registered on the input
          // device, skip the optimization.
          Status s = IsKernelRegisteredForNode(size_node);
          if (!s.ok()) {
            continue;
          }

          // Replace the Prod node's contents in place with the new Size node.
          fanout.node->Swap(&size_node);
        }
      }
    }
  }
  {
    MutableGraphView graph(optimized_graph);
    for (auto& node : *optimized_graph->mutable_node()) {
      // Try to convert the ratio of 2 symbolic tensor sizes into a constant.
      // This is possible whenever the symbolic dimensions in the numerator and
      // denominator cancel each other.
      if (node.op() == "Div") {
        const MutableGraphView::OutputPort input1 =
            graph.GetRegularFanin(MutableGraphView::InputPort(&node, 0));
        const MutableGraphView::OutputPort input2 =
            graph.GetRegularFanin(MutableGraphView::InputPort(&node, 1));
        if (input1.node == nullptr || input2.node == nullptr) continue;
        // Only handle Div whose both operands are Size nodes.
        if (!IsSize(*input1.node) || !IsSize(*input2.node)) {
          continue;
        }
        if (!inferred_properties) {
          // Infer properties lazily in case they are not needed.
          TF_RETURN_IF_ERROR(
              properties.InferStatically(/*assume_valid_feeds=*/false,
                                         /*aggressive_shape_inference=*/false,
                                         /*include_tensor_values=*/false));
          inferred_properties = true;
        }
        const auto& prop1 = properties.GetInputProperties(input1.node->name());
        const auto& prop2 = properties.GetInputProperties(input2.node->name());
        if (prop1.size() != 1 || prop2.size() != 1) {
          continue;
        }
        const TensorShapeProto& shape1 = prop1[0].shape();
        const TensorShapeProto& shape2 = prop2[0].shape();
        // ComputeSizeRatio yields a negative value when the ratio cannot be
        // determined statically (e.g. unknown dimensions that don't cancel).
        int64_t result = ComputeSizeRatio(shape1, shape2);
        if (result >= 0) {
          // Replace div with constant.
          node.set_op("Const");
          DataType dtype = node.attr().at("T").type();
          node.mutable_attr()->erase("T");
          (*node.mutable_attr())["dtype"].set_type(dtype);
          TensorProto* t = (*node.mutable_attr())["value"].mutable_tensor();
          t->set_dtype(dtype);
          // Scalar tensor (empty shape) holding the computed ratio.
          *t->mutable_tensor_shape() = TensorShapeProto();
          if (dtype == DT_INT32) {
            t->add_int_val(result);
          } else {
            t->add_int64_val(result);
          }
          // Turn the former data inputs into control dependencies so the
          // original execution ordering is preserved.
          node.set_input(0, AsControlDependency(node.input(0)));
          node.set_input(1, AsControlDependency(node.input(1)));
        }
      }
    }
  }
  return OkStatus();
}
194
195 } // end namespace grappler
196 } // namespace tensorflow
197