1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass_internal.h"
17
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/const_op.h"
21 #include "tensorflow/cc/ops/linalg_ops.h"
22 #include "tensorflow/cc/ops/math_ops.h"
23 #include "tensorflow/compiler/jit/node_matchers.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26
27 namespace tensorflow {
28 namespace {
29
30 using testing::matchers::Const;
31 using testing::matchers::Inputs;
32 using testing::matchers::Name;
33 using testing::matchers::NodeWith;
34 using testing::matchers::Op;
35 using testing::matchers::Out;
36
TEST(IntroduceFloatingPointJitterTest, SingleOutputFP32) {
  // Builds two independent Placeholder -> Sigmoid -> Tanh chains in DT_FLOAT,
  // asks the pass to jitter both sigmoid outputs, and checks that each Tanh now
  // reads from an Add(Const(jitter), sigmoid) node instead of the raw sigmoid.
  constexpr float kJitter = 0.01f;
  Scope scope = Scope::NewRootScope().ExitOnError();

  Output placeholder_a =
      ops::Placeholder(scope.WithOpName("input_a"), DT_FLOAT);
  Output placeholder_b =
      ops::Placeholder(scope.WithOpName("input_b"), DT_FLOAT);

  Output sig_a = ops::Sigmoid(scope.WithOpName("sigmoid_a"), placeholder_a);
  Output sig_b = ops::Sigmoid(scope.WithOpName("sigmoid_b"), placeholder_b);

  Output th_a = ops::Tanh(scope.WithOpName("tanh_a"), sig_a);
  Output th_b = ops::Tanh(scope.WithOpName("tanh_b"), sig_b);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Tensors (by node name) that should receive jitter.
  std::vector<string> targets = {"sigmoid_a", "sigmoid_b"};

  TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), targets, kJitter));
  VLOG(1) << graph->ToGraphDefDebug().DebugString();

  // Expected shape after the rewrite: Tanh <- Add(Const(kJitter), sigmoid).
  auto expect_jittered_tanh = [&](const string& sigmoid_name) {
    auto jitter_add = NodeWith(
        Op("Add"), Inputs(Const(kJitter), Out(NodeWith(Name(sigmoid_name)))));
    return NodeWith(Op("Tanh"), Inputs(Out(jitter_add)));
  };

  Node* tanh_a_node = testing::FindNodeByName(graph.get(), "tanh_a");
  Node* tanh_b_node = testing::FindNodeByName(graph.get(), "tanh_b");

  ASSERT_NE(tanh_a_node, nullptr);
  ASSERT_NE(tanh_b_node, nullptr);

  EXPECT_THAT(tanh_a_node, expect_jittered_tanh("sigmoid_a"));
  EXPECT_THAT(tanh_b_node, expect_jittered_tanh("sigmoid_b"));
}
78
TEST(IntroduceFloatingPointJitterTest, TwoNodesOneUser) {
  // Two sigmoid outputs feed a single Add ("add"). Both sigmoids are jittered,
  // so "add" is expected to consume two jitter Add nodes, one per input, rather
  // than the sigmoids themselves.
  constexpr float kJitter = 0.01f;
  Scope scope = Scope::NewRootScope().ExitOnError();

  Output placeholder_a =
      ops::Placeholder(scope.WithOpName("input_a"), DT_FLOAT);
  Output placeholder_b =
      ops::Placeholder(scope.WithOpName("input_b"), DT_FLOAT);

  Output sig_a = ops::Sigmoid(scope.WithOpName("sigmoid_a"), placeholder_a);
  Output sig_b = ops::Sigmoid(scope.WithOpName("sigmoid_b"), placeholder_b);

  Output sum = ops::Add(scope.WithOpName("add"), sig_a, sig_b);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  std::vector<string> targets = {"sigmoid_a", "sigmoid_b"};

  TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), targets, kJitter));
  VLOG(1) << graph->ToGraphDefDebug().DebugString();

  // Matches the output of an Add(Const(kJitter), <producer>) jitter node.
  auto jittered_output_of = [&](const string& producer) {
    return Out(NodeWith(
        Op("Add"), Inputs(Const(kJitter), Out(NodeWith(Name(producer))))));
  };

  auto expected_add =
      NodeWith(Op("Add"), Inputs(jittered_output_of("sigmoid_a"),
                                 jittered_output_of("sigmoid_b")));

  Node* add_node = testing::FindNodeByName(graph.get(), "add");

  ASSERT_NE(add_node, nullptr);

  EXPECT_THAT(add_node, expected_add);
}
117
TEST(IntroduceFloatingPointJitterTest, NotFP32) {
  // Placeholder -> Sigmoid -> Tanh chain in DT_HALF. The expected jitter
  // constant is an Eigen::half scalar, matching the dtype of the jittered
  // sigmoid output.
  Scope scope = Scope::NewRootScope().ExitOnError();

  Output placeholder = ops::Placeholder(scope.WithOpName("input"), DT_HALF);
  Output sig = ops::Sigmoid(scope.WithOpName("sigmoid"), placeholder);
  Output th = ops::Tanh(scope.WithOpName("tanh"), sig);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  std::vector<string> targets = {"sigmoid"};

  TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), targets, 0.01f));
  VLOG(1) << graph->ToGraphDefDebug().DebugString();

  // Tanh <- Add(Const(half 0.01), sigmoid).
  auto half_jitter = Const(Tensor(Eigen::half(0.01f)));
  auto raw_sigmoid = Out(NodeWith(Name("sigmoid")));
  auto sigmoid_plus_jitter =
      NodeWith(Op("Add"), Inputs(half_jitter, raw_sigmoid));
  auto expected_tanh = NodeWith(Op("Tanh"), Inputs(Out(sigmoid_plus_jitter)));

  Node* tanh_node = testing::FindNodeByName(graph.get(), "tanh");

  ASSERT_NE(tanh_node, nullptr);

  EXPECT_THAT(tanh_node, expected_tanh);
}
147
TEST(IntroduceFloatingPointJitterTest, MultiOutput) {
  // Svd has three outputs (s, u, v at indices 0, 1, 2), each consumed by its
  // own Tanh. Only "svd:0" and "svd:2" are jittered. The consumer of output 1
  // must keep reading it directly.
  Scope scope = Scope::NewRootScope().ExitOnError();

  Output placeholder = ops::Placeholder(scope.WithOpName("input"), DT_HALF);

  ops::Svd decomposition(scope.WithOpName("svd"), placeholder);

  Output th_s = ops::Tanh(scope.WithOpName("tanh_s"), decomposition.s);
  Output th_u = ops::Tanh(scope.WithOpName("tanh_u"), decomposition.u);
  Output th_v = ops::Tanh(scope.WithOpName("tanh_v"), decomposition.v);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  std::vector<string> targets = {"svd:0", "svd:2"};

  TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), targets, 0.01f));
  VLOG(1) << graph->ToGraphDefDebug().DebugString();

  // Matches the output of Add(Const(half 0.01), svd:<index>).
  auto jittered_svd_output = [](int index) {
    return Out(NodeWith(Op("Add"),
                        Inputs(Const(Tensor(Eigen::half(0.01f))),
                               Out(index, NodeWith(Name("svd"))))));
  };

  auto expected_tanh_s =
      NodeWith(Op("Tanh"), Inputs(jittered_svd_output(0)));
  auto expected_tanh_u =
      NodeWith(Op("Tanh"), Inputs(Out(1, NodeWith(Name("svd")))));
  auto expected_tanh_v =
      NodeWith(Op("Tanh"), Inputs(jittered_svd_output(2)));

  Node* tanh_s_node = testing::FindNodeByName(graph.get(), "tanh_s");
  ASSERT_NE(tanh_s_node, nullptr);

  Node* tanh_u_node = testing::FindNodeByName(graph.get(), "tanh_u");
  ASSERT_NE(tanh_u_node, nullptr);

  Node* tanh_v_node = testing::FindNodeByName(graph.get(), "tanh_v");
  ASSERT_NE(tanh_v_node, nullptr);

  EXPECT_THAT(tanh_s_node, expected_tanh_s);
  EXPECT_THAT(tanh_u_node, expected_tanh_u);
  EXPECT_THAT(tanh_v_node, expected_tanh_v);
}
196 } // namespace
197 } // namespace tensorflow
198