/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h"

#include <string>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

using se::port::StatusOr;

class CloneConstantsForBetterClusteringPassImpl {
 public:
  explicit CloneConstantsForBetterClusteringPassImpl(Graph* graph)
      : graph_(graph), unique_name_counter_(0) {}
  Status Run();

 private:
  Status CloneSmallHostConstantInputs(
      const absl::flat_hash_set<string>& name_set, Node* n);
  string GenerateUniqueName(const absl::flat_hash_set<string>& name_set,
                            absl::string_view prefix);
  se::port::StatusOr<Node*> CloneNode(
      const absl::flat_hash_set<string>& name_set, Node* n);

  Graph* graph_;
  int unique_name_counter_;
};

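// Returns a name of the form "<prefix>/clone_<counter>" that does not collide
// with any name in `name_set`, bumping the counter until a free name is found.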
string CloneConstantsForBetterClusteringPassImpl::GenerateUniqueName(
    const absl::flat_hash_set<string>& name_set, absl::string_view prefix) {
  string candidate;
  do {
    candidate = absl::StrCat(prefix, "/clone_", unique_name_counter_++);
  } while (name_set.contains(candidate));
  return candidate;
}

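// Adds a copy of `n` to the graph under a fresh unique name, replicates all of
// `n`'s incoming edges onto the copy, and keeps the same device assignment.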
StatusOr<Node*> CloneConstantsForBetterClusteringPassImpl::CloneNode(
    const absl::flat_hash_set<string>& name_set, Node* n) {
  NodeDef new_in_def = n->def();
  new_in_def.clear_input();
  new_in_def.set_name(GenerateUniqueName(name_set, new_in_def.name()));
  TF_ASSIGN_OR_RETURN(Node * new_in, graph_->AddNode(new_in_def));

  for (const Edge* e : n->in_edges()) {
    if (e->IsControlEdge()) {
      graph_->AddControlEdge(e->src(), new_in);
    } else {
      graph_->AddEdge(e->src(), e->src_output(), new_in, e->dst_input());
    }
  }

  new_in->set_assigned_device_name(n->assigned_device_name());
  return new_in;
}

namespace {
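// Returns true if the constant `n` produces a tensor resident in host memory:
// int32 outputs are always host-resident, otherwise the assigned device must
// be a CPU.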
StatusOr<bool> IsConstantOnHost(Node* n) {
  if (n->output_type(0) == DT_INT32) {
    // TensorFlow always puts int32 tensors on the host.
    return true;
  }

  DeviceNameUtils::ParsedName parsed;
  TF_RET_CHECK(
      DeviceNameUtils::ParseFullName(n->assigned_device_name(), &parsed));
  return parsed.type == DEVICE_CPU;
}

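// Returns true if the constant `n` holds fewer than kSmallTensorThreshold
// elements; returns an error if any dimension of its shape is unknown.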
StatusOr<bool> IsConstantSmall(Node* n) {
  const TensorProto* proto = nullptr;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto));

  // TODO(sanjoy): It may make sense to combine this threshold with XLA's "large
  // constant" threshold, if there is one.
  const int kSmallTensorThreshold = 16;
  int64_t total_elements = 1;
  for (const auto& dim : proto->tensor_shape().dim()) {
    if (dim.size() < 0) {
      return errors::Internal("Unknown dimension size in constant tensor ",
                              n->name());
    }
    total_elements *= dim.size();
  }
  return total_elements < kSmallTensorThreshold;
}

// We only clone host constants for now since we want to avoid increasing memory
// pressure on GPUs.
StatusOr<bool> IsSmallHostConstant(Node* n) {
  if (!n->IsConstant()) {
    return false;
  }

  TF_ASSIGN_OR_RETURN(bool is_constant_on_host, IsConstantOnHost(n));
  if (!is_constant_on_host) {
    return false;
  }

  return IsConstantSmall(n);
}

bool IsInPlaceOp(absl::string_view op_name) {
  return op_name == "InplaceUpdate" || op_name == "InplaceAdd" ||
         op_name == "InplaceSub";
}
}  // namespace

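// For every input edge of `n` whose source is a small host constant with more
// than one consumer, clones the constant and rewires the edge so that `n`
// reads from its own private copy.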
Status CloneConstantsForBetterClusteringPassImpl::CloneSmallHostConstantInputs(
    const absl::flat_hash_set<string>& name_set, Node* n) {
  std::vector<const Edge*> in_edges;
  // Get the edges and sort them so we clone in a deterministic order.
  absl::c_copy(n->in_edges(), std::back_inserter(in_edges));
  absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) {
    return e1->id() < e2->id();
  });
  for (const Edge* e : in_edges) {
    Node* input = e->src();
    TF_ASSIGN_OR_RETURN(bool is_small_host_constant,
                        IsSmallHostConstant(input));
    if (is_small_host_constant && input->out_edges().size() != 1) {
      VLOG(2) << "Cloning small host constant " << input->name();
      TF_ASSIGN_OR_RETURN(Node* const input_cloned, CloneNode(name_set, input));
      if (e->IsControlEdge()) {
        graph_->AddControlEdge(input_cloned, e->dst());
      } else {
        int dst_input = e->dst_input();
        TF_RET_CHECK(e->src_output() == 0)
            << "expected constant to have exactly one non-control output, but "
               "found output index = "
            << e->src_output();
        graph_->RemoveEdge(e);
        graph_->AddEdge(input_cloned, 0, n, dst_input);
      }
    }
  }
  return OkStatus();
}

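// Clones small host constant inputs for every node in the graph.  Leaves the
// graph untouched if it contains any in-place op; see the comment inside the
// loop for why.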
Status CloneConstantsForBetterClusteringPassImpl::Run() {
  absl::flat_hash_set<string> name_set;
  absl::c_transform(graph_->nodes(), std::inserter(name_set, name_set.begin()),
                    [](Node* n) { return n->name(); });
  std::vector<Node*> nodes;
  for (Node* n : graph_->nodes()) {
    // We rely on the immutability of Tensors to safely clone Const operations.
    // However, "in place" ops do not respect the immutability of Tensors so we
    // avoid this transformation when such ops are present in the graph.
    //
    // In-place operations are problematic because they break the semantic
    // illusion that tensorflow::Tensor instances are immutable.  For instance
    // if we have the following graph:
    //
    // digraph {
    //   SRC -> Const
    //   SRC -> I
    //   SRC -> V
    //   Const -> Identity
    //   Const -> InplaceAdd [label="x"]
    //   I -> InplaceAdd [label="i"]
    //   V -> InplaceAdd [label="v"]
    //   InplaceAdd -> Identity [style=dotted]
    // }
    //
    // then the value produced by `Identity` is Const+I*V since InplaceAdd
    // modifies the tensor in place.  However, if we clone `Const` and turn the
    // graph into:
    //
    // digraph {
    //   SRC -> "Const/clone_1"
    //   SRC -> "Const/clone_2"
    //   SRC -> I
    //   SRC -> V
    //   "Const/clone_1" -> Identity
    //   "Const/clone_2" -> InplaceAdd [label="x"]
    //   I -> InplaceAdd [label="i"]
    //   V -> InplaceAdd [label="v"]
    //   InplaceAdd -> Identity [style=dotted]
    // }
    //
    // then `Identity` no longer produces Const+I*V because the InplaceAdd
    // operation only modifies Const/clone_2 in place.

    if (IsInPlaceOp(n->type_string())) {
      return OkStatus();
    }
    nodes.push_back(n);
  }

  // Iterate over a copy of the nodes to avoid iterating over g->nodes() while
  // creating more nodes.
  for (Node* n : nodes) {
    TF_RETURN_IF_ERROR(CloneSmallHostConstantInputs(name_set, n));
  }
  return OkStatus();
}

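// Graph optimization pass entry point: skips the rewrite entirely when the
// global JIT level is OFF, otherwise runs the implementation above, dumping
// the graph before and after when VLOG(1) is enabled.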
Status CloneConstantsForBetterClusteringPass::Run(
    const GraphOptimizationPassOptions& options) {
  if (GetGlobalJitLevelForGraph(options) == OptimizerOptions::OFF) {
    return OkStatus();
  }

  Graph* g = options.graph->get();

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("before_clone_constants_for_better_clustering", *g);
  }

  TF_RETURN_IF_ERROR(CloneConstantsForBetterClusteringPassImpl{g}.Run());

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("after_clone_constants_for_better_clustering", *g);
  }

  return OkStatus();
}

}  // namespace tensorflow