1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/common_runtime/inspecting_placer.h"

#include <map>
#include <memory>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/colocation_graph.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/core/errors.h"
32
33 namespace tensorflow {
34
DebugString() const35 string IOColocationGroups::DebugString() const {
36 std::unordered_map<int, std::vector<string>> group_members;
37 for (int arg_index = 0; arg_index < input_groups.size(); ++arg_index) {
38 int group_id = input_groups[arg_index];
39 group_members[group_id].push_back(strings::StrCat("i:", arg_index));
40 }
41 for (int ret_index = 0; ret_index < output_groups.size(); ++ret_index) {
42 int group_id = output_groups[ret_index];
43 group_members[group_id].push_back(strings::StrCat("o:", ret_index));
44 }
45
46 std::vector<string> group_strings;
47 for (const auto& it : group_members) {
48 int group_id = it.first;
49 const std::vector<string>& members = it.second;
50 const PossibleDevices& devices = group_devices[group_id];
51 group_strings.push_back(strings::StrCat(
52 "Group(", group_id, " members = [", absl::StrJoin(members, ", "),
53 "] requested_device_name = \"",
54 DeviceNameUtils::ParsedNameToString(devices.requested_device_name),
55 "\" resource_device_name = \"",
56 DeviceNameUtils::ParsedNameToString(devices.resource_device_name),
57 "\" device_types = [",
58 absl::StrJoin(
59 devices.device_types, ", ",
60 [](string* out, const std::pair<DeviceType, int32>& type_and_pref) {
61 out->append(DeviceTypeString(type_and_pref.first));
62 }),
63 "])"));
64 }
65
66 return absl::StrJoin(group_strings, "\n\t");
67 }
68
69 // Utility class for constructing IOColocationGroups from a ColocationGraph.
70 class ColocationGraphToIOColocationGroups {
71 public:
72 // colocation_graph is mutable because finding root nodes can update
73 // parent pointers. It is not modified otherwise.
ColocationGraphToIOColocationGroups(ColocationGraph * colocation_graph)74 explicit ColocationGraphToIOColocationGroups(
75 ColocationGraph* colocation_graph)
76 : colocation_graph_(colocation_graph), next_group_id_(0) {}
77
AssignGroups(const gtl::InlinedVector<Node *,4> & nodes,std::vector<int> * groups)78 void AssignGroups(const gtl::InlinedVector<Node*, 4>& nodes,
79 std::vector<int>* groups) {
80 for (int i = 0; i < nodes.size(); ++i) {
81 int root_id = colocation_graph_->FindAndUpdateRoot(nodes[i]->id());
82 const auto& it = group_ids_.find(root_id);
83 int assigned_group_id;
84 if (it == group_ids_.end()) {
85 group_ids_[root_id] = next_group_id_;
86 assigned_group_id = next_group_id_;
87 ++next_group_id_;
88 } else {
89 assigned_group_id = it->second;
90 }
91 groups->push_back(assigned_group_id);
92 }
93 }
94
FillGroups(std::vector<PossibleDevices> * group_devices)95 Status FillGroups(std::vector<PossibleDevices>* group_devices) {
96 group_devices->resize(group_ids_.size());
97 for (const auto& it : group_ids_) {
98 int assigned_group_id = it.second;
99 PossibleDevices& possible_devices = (*group_devices)[assigned_group_id];
100 const Member& member = colocation_graph_->members()[it.first];
101 TF_RETURN_IF_ERROR(member.FillPossibleDevices(&possible_devices));
102 }
103 return Status::OK();
104 }
105
106 private:
107 ColocationGraph* colocation_graph_;
108 // Allocated group ids: collocation_graph root id -> allocated group id.
109 std::unordered_map<int, int> group_ids_;
110 int next_group_id_;
111 };
112
// Constructs an InspectingPlacer, capturing the context that
// ComputeIOColocationGroups() later passes on when it builds a
// ColocationGraph for a called function.
//
// `flib_def` and `device_set` are dereferenced here, so both must be non-null.
// `default_device` is stored as given, without a null check.
// NOTE(review): flib_def_ and device_set_ look like reference members (their
// declarations are in the header, not visible here); if so, the pointed-to
// objects must outlive this placer - confirm.
InspectingPlacer::InspectingPlacer(const FunctionStack& stack,
                                   const FunctionLibraryDefinition* flib_def,
                                   const DeviceSet* device_set,
                                   const Device* default_device,
                                   bool allow_soft_placement,
                                   bool log_device_placement)
    : stack_(stack),
      flib_def_(*flib_def),
      device_set_(*device_set),
      default_device_(default_device),
      allow_soft_placement_(allow_soft_placement),
      log_device_placement_(log_device_placement) {}
125
ComputeIOColocationGroups(const Node & node,IOColocationGroups * groups)126 Status InspectingPlacer::ComputeIOColocationGroups(const Node& node,
127 IOColocationGroups* groups) {
128 const FunctionDef* fdef;
129 NameAttrList func;
130 TF_RETURN_IF_ERROR(GetFunctionDefAndAttrs(flib_def_, node, &fdef, &func));
131 std::unique_ptr<FunctionBody> fbody;
132
133 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
134 &flib_def_, &fbody));
135
136 TF_RETURN_IF_ERROR(
137 IsolatePlacerInspectionRequiredOps(flib_def_, fbody->graph));
138 if (stack_.HasFunction(func.name())) {
139 return errors::Unimplemented(
140 "Recursive function calls are not supported. Node ",
141 FormatNodeForError(node), " inside the body of ",
142 errors::FormatFunctionForError(stack_.current_function_name()),
143 " calls function ", errors::FormatFunctionForError(func.name()),
144 " which is already present in the call stack:\n ",
145 stack_.FormatForError());
146 }
147
148 ColocationGraph colocation_graph(
149 fbody->graph, stack_.Push(&node, func.name()), &flib_def_, &device_set_,
150 default_device_, allow_soft_placement_, log_device_placement_);
151 TF_RETURN_IF_ERROR(colocation_graph.Initialize());
152
153 ColocationGraphToIOColocationGroups converter(&colocation_graph);
154 converter.AssignGroups(fbody->arg_nodes, &groups->input_groups);
155 converter.AssignGroups(fbody->ret_nodes, &groups->output_groups);
156 TF_RETURN_IF_ERROR(converter.FillGroups(&groups->group_devices));
157 return Status::OK();
158 }
159
160 } // namespace tensorflow
161