/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass_internal.h"

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/linalg_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

using testing::matchers::Const;
using testing::matchers::Inputs;
using testing::matchers::Name;
using testing::matchers::NodeWith;
using testing::matchers::Op;
using testing::matchers::Out;

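// Each named FP32 tensor should be rerouted through an Add node that mixes in
// the 0.01f jitter constant before reaching its Tanh consumer.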
TEST(IntroduceFloatingPointJitterTest, SingleOutputFP32) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output input_a = ops::Placeholder(root.WithOpName("input_a"), DT_FLOAT);
  Output input_b = ops::Placeholder(root.WithOpName("input_b"), DT_FLOAT);

  Output sigmoid_a = ops::Sigmoid(root.WithOpName("sigmoid_a"), input_a);
  Output sigmoid_b = ops::Sigmoid(root.WithOpName("sigmoid_b"), input_b);

  Output tanh_a = ops::Tanh(root.WithOpName("tanh_a"), sigmoid_a);
  Output tanh_b = ops::Tanh(root.WithOpName("tanh_b"), sigmoid_b);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  std::vector<string> tensor_names;
  tensor_names.push_back("sigmoid_a");
  tensor_names.push_back("sigmoid_b");

  TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f));
  VLOG(1) << graph->ToGraphDefDebug().DebugString();

  auto m_sigmoid_a = Out(NodeWith(Name("sigmoid_a")));
  auto m_sigmoid_a_with_jitter =
      NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_a));
  auto m_tanh_a = NodeWith(Op("Tanh"), Inputs(Out(m_sigmoid_a_with_jitter)));

  auto m_sigmoid_b = Out(NodeWith(Name("sigmoid_b")));
  auto m_sigmoid_b_with_jitter =
      NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_b));
  auto m_tanh_b = NodeWith(Op("Tanh"), Inputs(Out(m_sigmoid_b_with_jitter)));

  Node* tanh_a_transformed = testing::FindNodeByName(graph.get(), "tanh_a");
  Node* tanh_b_transformed = testing::FindNodeByName(graph.get(), "tanh_b");

  ASSERT_NE(tanh_a_transformed, nullptr);
  ASSERT_NE(tanh_b_transformed, nullptr);

  EXPECT_THAT(tanh_a_transformed, m_tanh_a);
  EXPECT_THAT(tanh_b_transformed, m_tanh_b);
}

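// When two jittered tensors feed a single consumer, both of the consumer's
// inputs should be rewired to the corresponding jitter Add nodes.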
TEST(IntroduceFloatingPointJitterTest, TwoNodesOneUser) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output input_a = ops::Placeholder(root.WithOpName("input_a"), DT_FLOAT);
  Output input_b = ops::Placeholder(root.WithOpName("input_b"), DT_FLOAT);

  Output sigmoid_a = ops::Sigmoid(root.WithOpName("sigmoid_a"), input_a);
  Output sigmoid_b = ops::Sigmoid(root.WithOpName("sigmoid_b"), input_b);

  Output add = ops::Add(root.WithOpName("add"), sigmoid_a, sigmoid_b);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  std::vector<string> tensor_names;
  tensor_names.push_back("sigmoid_a");
  tensor_names.push_back("sigmoid_b");

  TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f));
  VLOG(1) << graph->ToGraphDefDebug().DebugString();

  auto m_sigmoid_a = Out(NodeWith(Name("sigmoid_a")));
  auto m_sigmoid_a_with_jitter =
      NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_a));

  auto m_sigmoid_b = Out(NodeWith(Name("sigmoid_b")));
  auto m_sigmoid_b_with_jitter =
      NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_b));

  auto m_add = NodeWith(Op("Add"), Inputs(Out(m_sigmoid_a_with_jitter),
                                          Out(m_sigmoid_b_with_jitter)));

  Node* add_transformed = testing::FindNodeByName(graph.get(), "add");

  ASSERT_NE(add_transformed, nullptr);

  EXPECT_THAT(add_transformed, m_add);
}

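// For a non-FP32 tensor (DT_HALF here) the jitter constant should be created
// in the tensor's own dtype, i.e. as an Eigen::half constant.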
TEST(IntroduceFloatingPointJitterTest, NotFP32) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output input = ops::Placeholder(root.WithOpName("input"), DT_HALF);

  Output sigmoid = ops::Sigmoid(root.WithOpName("sigmoid"), input);

  Output tanh = ops::Tanh(root.WithOpName("tanh"), sigmoid);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  std::vector<string> tensor_names;
  tensor_names.push_back("sigmoid");

  TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f));
  VLOG(1) << graph->ToGraphDefDebug().DebugString();

  auto m_sigmoid = Out(NodeWith(Name("sigmoid")));
  auto m_sigmoid_with_jitter =
      NodeWith(Op("Add"), Inputs(Const(Tensor(Eigen::half(0.01f))), m_sigmoid));
  auto m_tanh = NodeWith(Op("Tanh"), Inputs(Out(m_sigmoid_with_jitter)));

  Node* tanh_transformed = testing::FindNodeByName(graph.get(), "tanh");

  ASSERT_NE(tanh_transformed, nullptr);

  EXPECT_THAT(tanh_transformed, m_tanh);
}

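// On a multi-output node (Svd), only the outputs named in tensor_names
// ("svd:0" and "svd:2") should receive jitter; "svd:1" must stay untouched.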
TEST(IntroduceFloatingPointJitterTest, MultiOutput) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output input = ops::Placeholder(root.WithOpName("input"), DT_HALF);

  ops::Svd svd(root.WithOpName("svd"), input);

  Output tanh_s = ops::Tanh(root.WithOpName("tanh_s"), svd.s);
  Output tanh_u = ops::Tanh(root.WithOpName("tanh_u"), svd.u);
  Output tanh_v = ops::Tanh(root.WithOpName("tanh_v"), svd.v);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  std::vector<string> tensor_names;
  tensor_names.push_back("svd:0");
  tensor_names.push_back("svd:2");

  TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f));
  VLOG(1) << graph->ToGraphDefDebug().DebugString();

  auto m_svd_s = Out(0, NodeWith(Name("svd")));
  auto m_svd_s_with_jitter = Out(
      NodeWith(Op("Add"), Inputs(Const(Tensor(Eigen::half(0.01f))), m_svd_s)));

  auto m_svd_u = Out(1, NodeWith(Name("svd")));

  auto m_svd_v = Out(2, NodeWith(Name("svd")));
  auto m_svd_v_with_jitter = Out(
      NodeWith(Op("Add"), Inputs(Const(Tensor(Eigen::half(0.01f))), m_svd_v)));

  auto m_tanh_s = NodeWith(Op("Tanh"), Inputs(m_svd_s_with_jitter));
  auto m_tanh_u = NodeWith(Op("Tanh"), Inputs(m_svd_u));
  auto m_tanh_v = NodeWith(Op("Tanh"), Inputs(m_svd_v_with_jitter));

  Node* tanh_s_transformed = testing::FindNodeByName(graph.get(), "tanh_s");
  ASSERT_NE(tanh_s_transformed, nullptr);

  Node* tanh_u_transformed = testing::FindNodeByName(graph.get(), "tanh_u");
  ASSERT_NE(tanh_u_transformed, nullptr);

  Node* tanh_v_transformed = testing::FindNodeByName(graph.get(), "tanh_v");
  ASSERT_NE(tanh_v_transformed, nullptr);

  EXPECT_THAT(tanh_s_transformed, m_tanh_s);
  EXPECT_THAT(tanh_u_transformed, m_tanh_u);
  EXPECT_THAT(tanh_v_transformed, m_tanh_v);
}
}  // namespace
}  // namespace tensorflow