1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h"
17
18 #include <string>
19
20 #include "absl/algorithm/container.h"
21 #include "tensorflow/compiler/jit/xla_cluster_util.h"
22 #include "tensorflow/compiler/xla/status_macros.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/util/dump_graph.h"
26
27 namespace tensorflow {
28
29 using se::port::StatusOr;
30
31 class CloneConstantsForBetterClusteringPassImpl {
32 public:
CloneConstantsForBetterClusteringPassImpl(Graph * graph)33 explicit CloneConstantsForBetterClusteringPassImpl(Graph* graph)
34 : graph_(graph), unique_name_counter_(0) {}
35 Status Run();
36
37 private:
38 Status CloneSmallHostConstantInputs(
39 const absl::flat_hash_set<string>& name_set, Node* n);
40 string GenerateUniqueName(const absl::flat_hash_set<string>& name_set,
41 absl::string_view prefix);
42 se::port::StatusOr<Node*> CloneNode(
43 const absl::flat_hash_set<string>& name_set, Node* n);
44
45 Graph* graph_;
46 int unique_name_counter_;
47 };
48
GenerateUniqueName(const absl::flat_hash_set<string> & name_set,absl::string_view prefix)49 string CloneConstantsForBetterClusteringPassImpl::GenerateUniqueName(
50 const absl::flat_hash_set<string>& name_set, absl::string_view prefix) {
51 string candidate;
52 do {
53 candidate = absl::StrCat(prefix, "/clone_", unique_name_counter_++);
54 } while (name_set.contains(candidate));
55 return candidate;
56 }
57
CloneNode(const absl::flat_hash_set<string> & name_set,Node * n)58 StatusOr<Node*> CloneConstantsForBetterClusteringPassImpl::CloneNode(
59 const absl::flat_hash_set<string>& name_set, Node* n) {
60 NodeDef new_in_def = n->def();
61 new_in_def.clear_input();
62 new_in_def.set_name(GenerateUniqueName(name_set, new_in_def.name()));
63 TF_ASSIGN_OR_RETURN(Node * new_in, graph_->AddNode(new_in_def));
64
65 for (const Edge* e : n->in_edges()) {
66 if (e->IsControlEdge()) {
67 graph_->AddControlEdge(e->src(), new_in);
68 } else {
69 graph_->AddEdge(e->src(), e->src_output(), new_in, e->dst_input());
70 }
71 }
72
73 new_in->set_assigned_device_name(n->assigned_device_name());
74 return new_in;
75 }
76
77 namespace {
IsConstantOnHost(Node * n)78 StatusOr<bool> IsConstantOnHost(Node* n) {
79 if (n->output_type(0) == DT_INT32) {
80 // TensorFlow always puts int32 tensors on the host.
81 return true;
82 }
83
84 DeviceNameUtils::ParsedName parsed;
85 TF_RET_CHECK(
86 DeviceNameUtils::ParseFullName(n->assigned_device_name(), &parsed));
87 return parsed.type == DEVICE_CPU;
88 }
89
IsConstantSmall(Node * n)90 StatusOr<bool> IsConstantSmall(Node* n) {
91 const TensorProto* proto = nullptr;
92 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto));
93
94 // TODO(sanjoy): It may make sense to combine this threshold with XLA's "large
95 // constant" threshold, if there is one.
96 const int kSmallTensorThreshold = 16;
97 int64_t total_elements = 1;
98 for (const auto& dim : proto->tensor_shape().dim()) {
99 if (dim.size() < 0) {
100 return errors::Internal("Unknown dimension size in constant tensor ",
101 n->name());
102 }
103 total_elements *= dim.size();
104 }
105 return total_elements < kSmallTensorThreshold;
106 }
107
108 // We only clone host constants for now since we want to avoid increasing memory
109 // pressure on GPUs.
IsSmallHostConstant(Node * n)110 StatusOr<bool> IsSmallHostConstant(Node* n) {
111 if (!n->IsConstant()) {
112 return false;
113 }
114
115 TF_ASSIGN_OR_RETURN(bool is_constant_on_host, IsConstantOnHost(n));
116 if (!is_constant_on_host) {
117 return false;
118 }
119
120 return IsConstantSmall(n);
121 }
122
IsInPlaceOp(absl::string_view op_name)123 bool IsInPlaceOp(absl::string_view op_name) {
124 return op_name == "InplaceUpdate" || op_name == "InplaceAdd" ||
125 op_name == "InplaceSub";
126 }
127 } // namespace
128
CloneSmallHostConstantInputs(const absl::flat_hash_set<string> & name_set,Node * n)129 Status CloneConstantsForBetterClusteringPassImpl::CloneSmallHostConstantInputs(
130 const absl::flat_hash_set<string>& name_set, Node* n) {
131 std::vector<const Edge*> in_edges;
132 // Get the edges and sort them so we clone in a deterministic order.
133 absl::c_copy(n->in_edges(), std::back_inserter(in_edges));
134 absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) {
135 return e1->id() < e2->id();
136 });
137 for (const Edge* e : in_edges) {
138 Node* input = e->src();
139 TF_ASSIGN_OR_RETURN(bool is_small_host_constant,
140 IsSmallHostConstant(input));
141 if (is_small_host_constant && input->out_edges().size() != 1) {
142 VLOG(2) << "Cloning small host constant " << input->name();
143 TF_ASSIGN_OR_RETURN(Node* const input_cloned, CloneNode(name_set, input));
144 if (e->IsControlEdge()) {
145 graph_->AddControlEdge(input_cloned, e->dst());
146 } else {
147 int dst_input = e->dst_input();
148 TF_RET_CHECK(e->src_output() == 0)
149 << "expected constant to have exactly one non-control output, but "
150 "found output index = "
151 << e->src_output();
152 graph_->RemoveEdge(e);
153 graph_->AddEdge(input_cloned, 0, n, dst_input);
154 }
155 }
156 }
157 return OkStatus();
158 }
159
Run()160 Status CloneConstantsForBetterClusteringPassImpl::Run() {
161 absl::flat_hash_set<string> name_set;
162 absl::c_transform(graph_->nodes(), std::inserter(name_set, name_set.begin()),
163 [](Node* n) { return n->name(); });
164 std::vector<Node*> nodes;
165 for (Node* n : graph_->nodes()) {
166 // We rely on the immutability of Tensors to safely clone Const operations.
167 // However, "in place" ops do not respect the immutability of Tensors so we
168 // avoid this transformation when such ops are present in the graph.
169 //
170 // In-place operations are problematic because they break the semantic
171 // illusion that tensorflow::Tensor instances are immutable. For instance
172 // if we have the following graph:
173 //
174 // digraph {
175 // SRC -> Const
176 // SRC -> I
177 // SRC -> V
178 // Const -> Identity
179 // Const -> InplaceAdd [label="x"]
180 // I -> InplaceAdd [label="i"]
181 // V -> InplaceAdd [label="v"]
182 // InplaceAdd -> Identity [style=dotted]
183 // }
184 //
185 // then the value produced by `Identity` is Const+I*V since InplaceAdd
186 // modifies the tensor in place. However, if we clone `Const` and turn the
187 // graph into:
188 //
189 // digraph {
190 // SRC -> "Const/clone_1"
191 // SRC -> "Const/clone_2"
192 // SRC -> I
193 // SRC -> V
194 // "Const/clone_1" -> Identity
195 // "Const/clone_2" -> InplaceAdd [label="x"]
196 // I -> InplaceAdd [label="i"]
197 // V -> InplaceAdd [label="v"]
198 // InplaceAdd -> Identity [style=dotted]
199 // }
200 //
201 // then `Identity` no longer produces Const+I*V because the InplaceAdd
202 // operation only modifies Const/clone_2 in place.
203
204 if (IsInPlaceOp(n->type_string())) {
205 return OkStatus();
206 }
207 nodes.push_back(n);
208 }
209
210 // Iterate over a copy of the nodes to avoid iterating over g->nodes() while
211 // creating more nodes.
212 for (Node* n : nodes) {
213 TF_RETURN_IF_ERROR(CloneSmallHostConstantInputs(name_set, n));
214 }
215 return OkStatus();
216 }
217
Run(const GraphOptimizationPassOptions & options)218 Status CloneConstantsForBetterClusteringPass::Run(
219 const GraphOptimizationPassOptions& options) {
220 if (GetGlobalJitLevelForGraph(options) == OptimizerOptions::OFF) {
221 return OkStatus();
222 }
223
224 Graph* g = options.graph->get();
225
226 if (VLOG_IS_ON(1)) {
227 DumpGraphToFile("before_clone_constants_for_better_clustering", *g);
228 }
229
230 TF_RETURN_IF_ERROR(CloneConstantsForBetterClusteringPassImpl{g}.Run());
231
232 if (VLOG_IS_ON(1)) {
233 DumpGraphToFile("after_clone_constants_for_better_clustering", *g);
234 }
235
236 return OkStatus();
237 }
238
239 } // namespace tensorflow
240