1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/grappler/utils/frame.h"

#include <algorithm>
#include <deque>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/lib/core/errors.h"
23
24 namespace tensorflow {
25 namespace grappler {
26
27 namespace {} // namespace
28
InferFromGraphView(const GraphView & graph_view)29 Status FrameView::InferFromGraphView(const GraphView& graph_view) {
30 if (is_inferred_) {
31 return errors::Internal("FrameView was already inferred from the graph");
32 }
33 is_inferred_ = true;
34
35 std::deque<const NodeDef*> ready_nodes;
36
37 // All nodes without inputs are automatically added to the ready queue.
38 for (const NodeDef& node : graph_view.graph()->node()) {
39 if (node.input_size() == 0) {
40 ready_nodes.push_back(&node);
41 node_to_frames_[&node] = node_has_no_frames_;
42 }
43 }
44
45 // We assign unique int id to each frame, and use this map to track what
46 // frames we've already seen in the graph.
47 absl::flat_hash_map<string, int> frame_name_to_id;
48
49 while (!ready_nodes.empty()) {
50 const NodeDef* ready_node = ready_nodes.front();
51
52 absl::flat_hash_set<GraphView::InputPort> fanouts =
53 graph_view.GetFanouts(*ready_node, /*include_controlled_nodes=*/true);
54
55 for (const GraphView::InputPort& fanout : fanouts) {
56 if (node_to_frames_.count(fanout.node) < 1) {
57 // If we have never seen this node before, we add all frames from the
58 // incoming node (and pop/push frames if coming from Exit/Enter nodes).
59 std::vector<int> frame_ids = node_to_frames_[ready_node];
60
61 if (IsExit(*ready_node)) {
62 frame_ids.pop_back();
63 }
64
65 if (IsEnter(*fanout.node)) {
66 const AttrValue* frame_name_attr =
67 AttrSlice(*fanout.node).Find("frame_name");
68
69 if (!frame_name_attr) {
70 return errors::InvalidArgument(
71 "Missing frame name for the Enter node: ",
72 SummarizeNodeDef(*fanout.node));
73 }
74
75 absl::string_view frame_name = frame_name_attr->s();
76 int frame_id;
77
78 if (frame_name_to_id.count(frame_name)) {
79 frame_id = frame_name_to_id[frame_name];
80 } else {
81 frame_id = static_cast<int>(frame_name_to_id.size());
82 frame_name_to_id[frame_name] = frame_id;
83 }
84
85 frame_ids.push_back(frame_id);
86 }
87
88 ready_nodes.push_back(fanout.node);
89 node_to_frames_[fanout.node] = std::move(frame_ids);
90
91 } else {
92 // If we've already seen this node before, we need to make sure that
93 // graph is correct and same nodes doesn't have incoming edges with
94 // conflicting frames (all inputs must be produces in the same frame).
95
96 std::vector<int> frame_ids_fanout = node_to_frames_[fanout.node];
97 std::vector<int> frame_ids_node = node_to_frames_[ready_node];
98
99 if (IsEnter(*fanout.node)) {
100 frame_ids_fanout.pop_back();
101 }
102 if (IsExit(*ready_node)) {
103 frame_ids_node.pop_back();
104 }
105
106 if (frame_ids_node != frame_ids_fanout) {
107 return errors::InvalidArgument(
108 "Invalid graph: Frame ids for node ", ready_node->name(),
109 " does not match frame ids for it's fanout ",
110 fanout.node->name());
111 }
112 }
113 }
114
115 ready_nodes.pop_front();
116 }
117
118 num_frames_ = static_cast<int>(frame_name_to_id.size());
119 return Status::OK();
120 }
121
InferFromGraph(const GraphDef & graph)122 Status FrameView::InferFromGraph(const GraphDef& graph) {
123 return InferFromGraphView(GraphView(&graph));
124 }
125
Frames(const NodeDef & node) const126 const std::vector<int>& FrameView::Frames(const NodeDef& node) const {
127 DCHECK(is_inferred_) << "FrameView is not initialized";
128 auto frames = node_to_frames_.find(&node);
129 if (frames == node_to_frames_.end()) {
130 LOG(WARNING) << "Node doesn't belong to the graph used for initialization";
131 return node_has_no_frames_;
132 } else {
133 return frames->second;
134 }
135 }
136
IsInFrame(const NodeDef & node) const137 bool FrameView::IsInFrame(const NodeDef& node) const {
138 return !Frames(node).empty();
139 }
140
141 } // namespace grappler
142 } // namespace tensorflow
143