1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || INTEL_MKL
17
18 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
19
20 #include <utility>
21 #include <vector>
22
23 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
24 #include "tensorflow/cc/ops/list_ops.h"
25 #include "tensorflow/cc/ops/math_ops.h"
26 #include "tensorflow/cc/ops/standard_ops.h"
27 #include "tensorflow/core/framework/function_testlib.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/tensor_testutil.h"
30 #include "tensorflow/core/grappler/clusters/single_machine.h"
31 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
32 #include "tensorflow/core/grappler/devices.h"
33 #include "tensorflow/core/grappler/graph_view.h"
34 #include "tensorflow/core/grappler/utils/grappler_test.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/lib/random/random.h"
37 #include "tensorflow/core/util/util.h"
38
39 // TODO(benbarsdell): Improve the numerical checks in these tests. The tests
40 // were originally written only to check the graph coloring, so the graphs do
41 // not have particularly realistic numerical behavior.
42
43 namespace tensorflow {
44 namespace grappler {
45 namespace {
46
47 template <DataType DTYPE>
GenerateIdentityMatrix(int64_t height,int64_t width)48 Tensor GenerateIdentityMatrix(int64_t height, int64_t width) {
49 typedef typename EnumToDataType<DTYPE>::Type T;
50 Tensor tensor(DTYPE, TensorShape{height, width});
51 for (int64_t i = 0; i < height; ++i) {
52 for (int64_t j = 0; j < width; ++j) {
53 tensor.matrix<T>()(i, j) = i == j;
54 }
55 }
56 return tensor;
57 }
58
59 template <DataType DTYPE>
GenerateRandomTensorInRange(const TensorShape & shape,double minval,double maxval)60 Tensor GenerateRandomTensorInRange(const TensorShape& shape, double minval,
61 double maxval) {
62 typedef typename EnumToDataType<DTYPE>::Type T;
63 Tensor tensor(DTYPE, shape);
64 for (auto i = 0; i < tensor.NumElements(); i++)
65 tensor.flat<T>()(i) =
66 (random::New64() % 65536 / 65536.0) * (maxval - minval) + minval;
67 return tensor;
68 }
69
VerifyGraphsEquivalent(const GraphDef & original_graph,const GraphDef & optimized_graph,const string & func)70 void VerifyGraphsEquivalent(const GraphDef& original_graph,
71 const GraphDef& optimized_graph,
72 const string& func) {
73 EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
74 GraphView optimized_view(&optimized_graph);
75 for (int i = 0; i < original_graph.node_size(); ++i) {
76 const NodeDef& original = original_graph.node(i);
77 const NodeDef& optimized = *optimized_view.GetNode(original.name());
78 EXPECT_EQ(original.name(), optimized.name()) << func;
79 EXPECT_EQ(original.op(), optimized.op()) << func;
80 EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
81 if (original.input_size() == optimized.input_size()) {
82 for (int j = 0; j < original.input_size(); ++j) {
83 EXPECT_EQ(original.input(j), optimized.input(j)) << func;
84 }
85 }
86 }
87 }
88
89 // Currently, this test suite only passes when TensorFlow passes with CUDA/HIP,
90 // because otherwise the optimizer will not turn clearlist nodes to float16.
91 // When looking at clearlist nodes, this optimizer checks if the nodes have a
92 // float16 GPU OpKernel, but without CUDA/HIP there are no GPU OpKernels at all.
93 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
94
95 const std::pair<int, int> kMinGPUArch = {7, 0};
96
// Test fixture for the AutoMixedPrecision grappler pass. SetUp() provisions
// either a real single-machine cluster (when local GPUs are present and, under
// CUDA, all meet kMinGPUArch) or a virtual GPU device, so the optimizer's
// graph coloring can be exercised on machines without suitable hardware.
class AutoMixedPrecisionTest : public GrapplerTest {
 protected:
  void SetUp() override {
    int num_gpus = GetNumAvailableGPUs();
    // If GPUs are available, require that they all satisfy the min arch.
    gpu_available_ = (num_gpus > 0);
#if GOOGLE_CUDA
    gpu_available_ =
        gpu_available_ && (num_gpus == GetNumAvailableGPUs(kMinGPUArch));
#else  // Here we force Tensorflow to use the virtual GFX906
    gpu_available_ = false;
#endif
    if (gpu_available_) {
      virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 1));
    } else {
      // Describe a virtual GPU that satisfies the optimizer's architecture
      // checks: CUDA compute capability 7 (+ CUDA 9.1), or ROCm gfx906.
      DeviceProperties device_properties;
      device_properties.set_type("GPU");
#if GOOGLE_CUDA
      device_properties.mutable_environment()->insert({"architecture", "7"});
      device_properties.mutable_environment()->insert({"cuda", "9010"});
#else
      device_properties.mutable_environment()->insert(
          {"architecture", "gfx906"});
#endif
      virtual_cluster_.reset(
          new VirtualCluster({{"/GPU:1", device_properties}}));
    }
    TF_CHECK_OK(virtual_cluster_->Provision());
  }

  void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }

  // Adds a node `name` of type `op` with the given `inputs` to `graph`,
  // filling in the type attribute(s) each op expects with DT_FLOAT (plus "N"
  // for AddN/ShapeN and "out_type" for ShapeN). Returns the new node.
  NodeDef* AddSimpleNode(const string& name, const string& op,
                         const std::vector<string>& inputs,
                         GraphDef* graph) const {
    std::vector<std::pair<string, AttrValue>> attributes;
    if (op == "AddN" || op == "ShapeN") {
      AttrValue num_inputs;
      num_inputs.set_i(inputs.size());
      attributes.emplace_back("N", num_inputs);
    }
    if (op == "ShapeN") {
      AttrValue out_type;
      out_type.set_type(DT_INT32);
      attributes.emplace_back("out_type", out_type);
    }
    AttrValue type;
    type.set_type(DT_FLOAT);
    // Different ops carry their float dtype under different attr names.
    if (op == "Const" || op == "Placeholder" || op == "VariableV2" ||
        op == "VarHandleOp" || op == "ReadVariableOp") {
      attributes.emplace_back("dtype", type);
    } else if (op == "SparseMatMul") {
      attributes.emplace_back("Ta", type);
      attributes.emplace_back("Tb", type);
    } else if (op == "IdentityN") {
      // IdentityN takes a list(type) attr, one entry per input.
      AttrValue type_list;
      for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
        type_list.mutable_list()->add_type(DT_FLOAT);
      }
      attributes.emplace_back("T", type_list);
    } else if (op == "StackV2" || op == "StackPopV2") {
      attributes.emplace_back("elem_type", type);
    } else if (op == "Cast") {
      attributes.emplace_back("SrcT", type);
      attributes.emplace_back("DstT", type);
    } else {
      attributes.emplace_back("T", type);
    }
    return AddNode(name, op, inputs, attributes, graph);
  }

  // Builds allow1(MatMul) -> infer1(test op) -> allow2(MatMul), runs the
  // optimizer, checks all three nodes were painted DT_HALF, and compares the
  // numerical result against the unoptimized graph within atol/rtol. Input
  // values are drawn uniformly from [input_min, input_max].
  void TestSimpleUnaryInferOp(
      double input_min, double input_max, double atol, double rtol,
      const std::function<Output(const tensorflow::Scope&, Output)>&
          test_op_factory) {
    int size = 128;
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    // MatMul by the identity keeps values unchanged while surrounding the op
    // under test with allowlist ops on both sides.
    Output eye = ops::Const(s.WithOpName("eye"),
                            GenerateIdentityMatrix<DT_FLOAT>(size, size));
    Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
    Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, eye);
    Output infer1 = test_op_factory(s.WithOpName("infer1"), allow1);
    Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, eye);
    Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
    GrapplerItem item;
    item.fetch = {"fetch1"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    auto input_tensor = GenerateRandomTensorInRange<DT_FLOAT>(
        TensorShape({size, size}), input_min, input_max);
    std::vector<std::pair<string, Tensor>> feed = {{"input", input_tensor}};
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);

    AutoMixedPrecision optimizer;
    GraphDef output;
    TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

    VLOG(1) << output.DebugString();

    GraphView output_view(&output);
    EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(),
              DT_FLOAT);
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
    EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
    EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

    auto tensors = EvaluateNodes(output, item.fetch, feed);
    EXPECT_EQ(tensors.size(), tensors_expected.size());
    EXPECT_EQ(tensors.size(), item.fetch.size());
    for (int i = 0; i < item.fetch.size(); ++i) {
      test::ExpectClose(tensors_expected[i], tensors[i], atol, rtol);
    }
  }

  std::unique_ptr<Cluster> virtual_cluster_;
  // True iff real GPUs were found (and, under CUDA, all meet kMinGPUArch).
  bool gpu_available_;
};
213
TEST_F(AutoMixedPrecisionTest,NoOp)214 TEST_F(AutoMixedPrecisionTest, NoOp) {
215 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
216 Output input = ops::Const(s.WithOpName("input"), 1.234f, {32});
217 Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
218 Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
219 Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
220 Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
221 Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
222
223 GrapplerItem item;
224 item.fetch = {"fetch"};
225 TF_CHECK_OK(s.ToGraphDef(&item.graph));
226 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
227
228 AutoMixedPrecision optimizer;
229 GraphDef output;
230 TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
231
232 VLOG(1) << output.DebugString();
233
234 VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
235
236 GraphView output_view(&output);
237 EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
238 EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
239 EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
240 EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
241 EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
242
243 auto tensors = EvaluateNodes(output, item.fetch);
244 EXPECT_EQ(tensors.size(), tensors_expected.size());
245 EXPECT_EQ(tensors.size(), item.fetch.size());
246 for (int i = 0; i < item.fetch.size(); ++i) {
247 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
248 }
249 }
250
TEST_F(AutoMixedPrecisionTest,AlreadyFp16)251 TEST_F(AutoMixedPrecisionTest, AlreadyFp16) {
252 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
253 Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
254 Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_HALF);
255 Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
256 Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
257 Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
258 Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
259 Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
260
261 GrapplerItem item;
262 item.fetch = {"fetch"};
263 TF_CHECK_OK(s.ToGraphDef(&item.graph));
264 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
265
266 AutoMixedPrecision optimizer;
267 GraphDef output;
268 TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
269 VLOG(1) << output.DebugString();
270
271 VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
272 GraphView output_view(&output);
273 EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
274 EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
275 EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
276 EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
277 EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_HALF);
278 EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
279 EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
280
281 auto tensors = EvaluateNodes(output, item.fetch);
282 EXPECT_EQ(tensors.size(), tensors_expected.size());
283 EXPECT_EQ(tensors.size(), item.fetch.size());
284 for (int i = 0; i < item.fetch.size(); ++i) {
285 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
286 }
287 }
288
TEST_F(AutoMixedPrecisionTest,Simple)289 TEST_F(AutoMixedPrecisionTest, Simple) {
290 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
291 Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
292 Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
293 Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
294 Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
295 Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
296 Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
297 Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
298 Output infer2 = ops::Log(s.WithOpName("infer2"), clr3);
299 Output clr4 = ops::Relu(s.WithOpName("clr4"), infer2);
300 Output deny2 = ops::SparseMatMul(s.WithOpName("deny2"), clr4, clr4);
301 Output clr5 = ops::Relu(s.WithOpName("clr5"), deny2);
302 Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);
303
304 GrapplerItem item;
305 item.fetch = {"fetch"};
306 TF_CHECK_OK(s.ToGraphDef(&item.graph));
307 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
308
309 AutoMixedPrecision optimizer;
310 GraphDef output;
311 TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
312
313 VLOG(1) << output.DebugString();
314
315 GraphView output_view(&output);
316 EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
317 EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
318 EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
319 EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
320 EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
321 EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
322 EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
323 EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
324 EXPECT_EQ(output_view.GetNode("infer2")->attr().at("T").type(), DT_FLOAT);
325 EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
326 EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Ta").type(), DT_FLOAT);
327 EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Tb").type(), DT_FLOAT);
328 EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);
329
330 auto tensors = EvaluateNodes(output, item.fetch);
331 EXPECT_EQ(tensors.size(), tensors_expected.size());
332 EXPECT_EQ(tensors.size(), item.fetch.size());
333 for (int i = 0; i < item.fetch.size(); ++i) {
334 test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
335 }
336 }
337
TEST_F(AutoMixedPrecisionTest,NoInferOp)338 TEST_F(AutoMixedPrecisionTest, NoInferOp) {
339 setenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "TREAT_INFER_AS_DENY",
340 1 /* replace */);
341
342 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
343 Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
344 Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
345 Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
346 Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
347 Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
348 Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
349 Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
350 Output infer2 = ops::Log(s.WithOpName("infer2"), clr3);
351 Output clr4 = ops::Relu(s.WithOpName("clr4"), infer2);
352 Output allow2 = ops::MatMul(s.WithOpName("allow2"), clr4, clr4);
353 Output infer3 = ops::Log(s.WithOpName("infer3"), allow2);
354 Output fetch = ops::Identity(s.WithOpName("fetch"), infer3);
355
356 GrapplerItem item;
357 item.fetch = {"fetch"};
358 TF_CHECK_OK(s.ToGraphDef(&item.graph));
359 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
360
361 AutoMixedPrecision optimizer;
362 GraphDef output;
363 TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
364
365 VLOG(1) << output.DebugString();
366
367 GraphView output_view(&output);
368 EXPECT_EQ(output.node_size(), item.graph.node_size() + 4);
369 EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
370 EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
371 EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
372 EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
373 EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
374 EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
375 EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
376 EXPECT_EQ(output_view.GetNode("infer2")->attr().at("T").type(), DT_FLOAT);
377 EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_HALF);
378 EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
379 EXPECT_EQ(output_view.GetNode("infer3")->attr().at("T").type(), DT_FLOAT);
380
381 auto tensors = EvaluateNodes(output, item.fetch);
382 EXPECT_EQ(tensors.size(), tensors_expected.size());
383 EXPECT_EQ(tensors.size(), item.fetch.size());
384 for (int i = 0; i < item.fetch.size(); ++i) {
385 test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
386 }
387 unsetenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL");
388 }
389
TEST_F(AutoMixedPrecisionTest,BidirectionalClearChain)390 TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) {
391 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
392 Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
393 Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
394 Output clr2 = ops::Relu(s.WithOpName("clr2"), input);
395 Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
396 auto clr3 = ops::ShapeN(s.WithOpName("clr3"), {clr1, clr2});
397 Output clr4 = ops::Relu(s.WithOpName("clr4"), clr2);
398 Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
399 Output fetch2 = ops::Identity(s.WithOpName("fetch2"), clr4);
400
401 GrapplerItem item;
402 item.fetch = {"fetch1", "fetch2"};
403 TF_CHECK_OK(s.ToGraphDef(&item.graph));
404 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
405
406 AutoMixedPrecision optimizer;
407 GraphDef output;
408 TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
409
410 VLOG(1) << output.DebugString();
411
412 GraphView output_view(&output);
413 EXPECT_EQ(output.node_size(), item.graph.node_size() + 3);
414 EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
415 EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
416 EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
417 EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
418 EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
419 EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_HALF);
420
421 auto tensors = EvaluateNodes(output, item.fetch);
422 EXPECT_EQ(tensors.size(), tensors_expected.size());
423 EXPECT_EQ(tensors.size(), item.fetch.size());
424 for (int i = 0; i < item.fetch.size(); ++i) {
425 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
426 }
427 }
428
TEST_F(AutoMixedPrecisionTest,PreserveFetches)429 TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
430 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
431 Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
432 Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
433 Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
434 Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
435 Output deny1 = ops::Exp(s.WithOpName("deny1"), infer1);
436 Output clr2 = ops::Relu(s.WithOpName("clr2"), deny1);
437 Output allow2 = ops::MatMul(s.WithOpName("allow2"), clr2, clr2);
438 Output clr3 = ops::Relu(s.WithOpName("clr3"), allow2);
439 Output deny2 = ops::Exp(s.WithOpName("deny2"), clr3);
440 Output clr4 = ops::Relu(s.WithOpName("clr4"), deny2);
441
442 GrapplerItem item;
443 item.fetch = {"allow1", "clr2", "clr3"};
444 TF_CHECK_OK(s.ToGraphDef(&item.graph));
445 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
446
447 AutoMixedPrecision optimizer;
448 GraphDef output;
449 TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
450
451 VLOG(1) << output.DebugString();
452
453 GraphView output_view(&output);
454 EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
455 EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
456 EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
457 EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
458 EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
459 EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
460 EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
461 EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
462 EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_FLOAT);
463 EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
464 EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
465
466 auto tensors = EvaluateNodes(output, item.fetch);
467 EXPECT_EQ(tensors.size(), tensors_expected.size());
468 EXPECT_EQ(tensors.size(), item.fetch.size());
469 for (int i = 0; i < item.fetch.size(); ++i) {
470 test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-3);
471 }
472 }
473
TEST_F(AutoMixedPrecisionTest,PreserveCPUNodes)474 TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) {
475 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
476 Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
477 Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
478 Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
479 Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
480 Output allow2 =
481 ops::MatMul(s.WithOpName("allow2").WithDevice(
482 "/job:localhost/replica:0/task:0/device:CPU:0"),
483 infer1, infer1);
484 Output clr2 = ops::Relu(s.WithOpName("clr2"), allow2);
485 Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
486
487 GrapplerItem item;
488 item.fetch = {"fetch"};
489 TF_CHECK_OK(s.ToGraphDef(&item.graph));
490 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
491
492 AutoMixedPrecision optimizer;
493 GraphDef output;
494 TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
495
496 VLOG(1) << output.DebugString();
497
498 GraphView output_view(&output);
499 EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
500 EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
501 EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
502 EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
503 EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
504 EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_FLOAT);
505 EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
506
507 auto tensors = EvaluateNodes(output, item.fetch);
508 EXPECT_EQ(tensors.size(), tensors_expected.size());
509 EXPECT_EQ(tensors.size(), item.fetch.size());
510 for (int i = 0; i < item.fetch.size(); ++i) {
511 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
512 }
513 }
514
TEST_F(AutoMixedPrecisionTest,PreserveIdentityAfterVariable)515 TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) {
516 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
517 Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
518 Output var1 = ops::Variable(s.WithOpName("var1"), {32, 32}, DT_FLOAT);
519 Output clr1 = ops::Identity(s.WithOpName("clr1"), var1);
520 Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, clr1);
521 Output input2 = ops::Const(s.WithOpName("input2"), 1.f / 32, {32, 32});
522 Output clr2 = ops::Identity(s.WithOpName("clr2"), input2);
523 Output allow2 = ops::MatMul(s.WithOpName("allow2"), input, clr2);
524 Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
525 Output fetch2 = ops::Identity(s.WithOpName("fetch2"), allow2);
526
527 GrapplerItem item;
528 item.fetch = {"fetch1", "fetch2"};
529 TF_CHECK_OK(s.ToGraphDef(&item.graph));
530 auto var1_tensor =
531 GenerateConstantTensor<DT_FLOAT>(TensorShape({32, 32}), 3.141593f);
532 std::vector<std::pair<string, Tensor>> feed = {{"var1", var1_tensor}};
533 auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
534
535 AutoMixedPrecision optimizer;
536 GraphDef output;
537 TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
538
539 VLOG(1) << output.DebugString();
540
541 GraphView output_view(&output);
542 EXPECT_EQ(output.node_size(), item.graph.node_size() + 5);
543 EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
544 EXPECT_EQ(output_view.GetNode("var1")->attr().at("dtype").type(), DT_FLOAT);
545 EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
546 EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
547 EXPECT_EQ(output_view.GetNode("input2")->attr().at("dtype").type(), DT_FLOAT);
548 EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
549 EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
550
551 auto tensors = EvaluateNodes(output, item.fetch, feed);
552 EXPECT_EQ(tensors.size(), tensors_expected.size());
553 EXPECT_EQ(tensors.size(), item.fetch.size());
554 for (int i = 0; i < item.fetch.size(); ++i) {
555 test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-3);
556 }
557 }
558
TEST_F(AutoMixedPrecisionTest,FusedBatchNorm)559 TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
560 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
561 // Uses NHWC data format because non-GPU execution does not support NCHW.
562 Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {8, 56, 56, 16});
563 Output weight = ops::Const(s.WithOpName("weight"), 2.f, {3, 3, 16, 16});
564 Output scale = ops::Const(s.WithOpName("scale"), 3.f, {16});
565 Output offset = ops::Const(s.WithOpName("offset"), 4.f, {16});
566 Output mean = ops::Const(s.WithOpName("mean"), 5.f, {0});
567 Output variance = ops::Const(s.WithOpName("variance"), 6.f, {0});
568 Output allow1 =
569 ops::Conv2D(s.WithOpName("allow1"), input, weight, {1, 1, 1, 1}, "SAME",
570 ops::Conv2D::DataFormat("NHWC"));
571 auto fbn1_op =
572 ops::FusedBatchNorm(s.WithOpName("fbn1"), allow1, scale, offset, mean,
573 variance, ops::FusedBatchNorm::DataFormat("NHWC"));
574 Output fbn1 = fbn1_op.y;
575 Output fbn1_rs1 = fbn1_op.reserve_space_1;
576 Output fbn1_rs2 = fbn1_op.reserve_space_2;
577 Output bng1 = ops::FusedBatchNormGrad(
578 s.WithOpName("bng1"), fbn1, allow1, scale, fbn1_rs1,
579 fbn1_rs2, ops::FusedBatchNormGrad::DataFormat("NHWC"))
580 .x_backprop;
581 Output infer1 = ops::Add(s.WithOpName("infer1"), fbn1, bng1);
582 Output allow2 =
583 ops::Conv2D(s.WithOpName("allow2"), infer1, weight, {1, 1, 1, 1}, "SAME",
584 ops::Conv2D::DataFormat("NHWC"));
585 Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);
586
587 GrapplerItem item;
588 item.fetch = {"fetch"};
589 TF_CHECK_OK(s.ToGraphDef(&item.graph));
590 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
591
592 AutoMixedPrecision optimizer;
593 GraphDef output;
594 TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
595
596 VLOG(1) << output.DebugString();
597
598 GraphView output_view(&output);
599 EXPECT_EQ(output.node_size(), item.graph.node_size() + 3);
600 EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
601 EXPECT_EQ(output_view.GetNode("fbn1")->op(), "FusedBatchNormV2");
602 EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("T").type(), DT_HALF);
603 EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("U").type(), DT_FLOAT);
604 EXPECT_EQ(output_view.GetNode("bng1")->op(), "FusedBatchNormGradV2");
605 EXPECT_EQ(output_view.GetNode("bng1")->attr().at("T").type(), DT_HALF);
606 EXPECT_EQ(output_view.GetNode("bng1")->attr().at("U").type(), DT_FLOAT);
607 EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
608 EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
609
610 auto tensors = EvaluateNodes(output, item.fetch);
611 EXPECT_EQ(tensors.size(), tensors_expected.size());
612 EXPECT_EQ(tensors.size(), item.fetch.size());
613 for (int i = 0; i < item.fetch.size(); ++i) {
614 test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-2);
615 }
616 }
617
TEST_F(AutoMixedPrecisionTest,RepeatedAndListTypeAttrs)618 TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
619 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
620 Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
621 Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
622 auto clr1_op = ops::IdentityN(s.WithOpName("clr1"), {allow1, allow1, allow1});
623 Output infer1 =
624 ops::AddN(s.WithOpName("infer1"),
625 {clr1_op.output[0], clr1_op.output[1], clr1_op.output[2]});
626 Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
627 Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);
628
629 GrapplerItem item;
630 item.fetch = {"fetch"};
631 TF_CHECK_OK(s.ToGraphDef(&item.graph));
632 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
633
634 AutoMixedPrecision optimizer;
635 GraphDef output;
636 TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
637
638 VLOG(1) << output.DebugString();
639
640 GraphView output_view(&output);
641 EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
642 EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
643 EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
644 for (auto type : output_view.GetNode("clr1")->attr().at("T").list().type()) {
645 EXPECT_EQ(type, DT_HALF);
646 }
647 EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
648 EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
649
650 auto tensors = EvaluateNodes(output, item.fetch);
651 EXPECT_EQ(tensors.size(), tensors_expected.size());
652 EXPECT_EQ(tensors.size(), item.fetch.size());
653 for (int i = 0; i < item.fetch.size(); ++i) {
654 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
655 }
656 }
657
TEST_F(AutoMixedPrecisionTest,ExistingCast)658 TEST_F(AutoMixedPrecisionTest, ExistingCast) {
659 tensorflow::Scope s = tensorflow::Scope::NewRootScope();
660 Output input = ops::Const(s.WithOpName("input"), true, {32, 32});
661 Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_FLOAT);
662 Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
663 Output fetch = ops::Identity(s.WithOpName("fetch"), allow1);
664
665 GrapplerItem item;
666 item.fetch = {"fetch"};
667 TF_CHECK_OK(s.ToGraphDef(&item.graph));
668 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
669
670 AutoMixedPrecision optimizer;
671 GraphDef output;
672 TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
673
674 VLOG(1) << output.DebugString();
675
676 GraphView output_view(&output);
677 EXPECT_EQ(output.node_size(), item.graph.node_size() + 1);
678 EXPECT_EQ(output_view.GetNode("cst1")->attr().at("SrcT").type(), DT_BOOL);
679 EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
680 EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
681
682 auto tensors = EvaluateNodes(output, item.fetch);
683 EXPECT_EQ(tensors.size(), tensors_expected.size());
684 EXPECT_EQ(tensors.size(), item.fetch.size());
685 for (int i = 0; i < item.fetch.size(); ++i) {
686 test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
687 }
688 }
689
TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) {
  // Builds a while-loop frame (Enter/Merge/Switch/NextIteration/Exit) whose
  // back edge connects an allow op (allow1) to a Merge that also sees a deny
  // op (deny1) upstream, and verifies that the paint colors stay consistent
  // around the recurrent edge: everything in the loop except allow1 remains
  // FP32.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output ent1 =
      ops::internal::Enter(s.WithOpName("ent1"), deny1, "loop1").output;
  // Note that the second input is later replaced with "nxt1".
  Output mrg1 = ops::Merge(s.WithOpName("mrg1"), {ent1, ent1}).output;
  // For simplicity, the loop condition is constant false.
  Output con1 = ops::Const(s.WithOpName("con1"), false, {});
  Output lpc1 = ops::LoopCond(s.WithOpName("lpc1"), con1).output;
  auto swt1 = ops::Switch(s.WithOpName("swt1"), mrg1, lpc1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), swt1.output_true);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), infer1, infer1);
  Output nxt1 = ops::NextIteration(s.WithOpName("nxt1"), allow1);
  Output ext1 = ops::internal::Exit(s.WithOpName("ext1"), swt1.output_false);
  Output fetch = ops::Identity(s.WithOpName("fetch"), ext1);
  // Add a second merge node from the same NextIteration node. This case arises
  // during graph optimization of some models.
  auto mrg2 = ops::Merge(s.WithOpName("mrg2"), {ent1, nxt1});

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  NodeMap node_map_original(&item.graph);
  auto merge_node = node_map_original.GetNode("mrg1");
  // Modify the graph to create a loop.
  merge_node->set_input(1, "nxt1");
  // Add a control edge to ensure the loop condition is inside the frame.
  auto const_node = node_map_original.GetNode("con1");
  const_node->add_input("^mrg1");
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  // Note that mrg1 gets painted deny because it is between deny1 and infer1.
  // This forces nxt1 and mrg2 to be painted deny as well (they would otherwise
  // be painted allow because they are clear and have a direct path to allow1).
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("ent1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("mrg1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("swt1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("nxt1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("ext1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("mrg2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
751
TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
  // Verifies that TensorList handles connected via SetItem/GetItem/Resize form
  // a single type cluster: tl1's element_dtype becomes DT_HALF because its
  // reads/writes touch allow ops (MatMul), while tl2 — which only exchanges
  // data with non-allow nodes — keeps DT_FLOAT.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output idx1 = ops::Const(s.WithOpName("idx1"), 1);
  Output idx2 = ops::Const(s.WithOpName("idx2"), 2);
  Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
  auto tl1w1 =
      ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1w2 =
      ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
  // Ensure that TensorListResize doesn't cause any problems.
  Output tl1rs =
      ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
  Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
                                        shape, DT_FLOAT)
                     .item;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  auto tl1w3 =
      ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
  Output tl1r2 =
      ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
                             shape, DT_FLOAT)
          .item;
  // tl2 is an independent list that never touches an allow op.
  auto tl2 = ops::TensorListReserve(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListSetItem(s.WithOpName("tl2w1"), tl2.handle, idx1, input);
  Output tl2r1 =
      ops::TensorListGetItem(s.WithOpName("tl2r1"), tl2w1.output_handle, idx1,
                             shape, DT_FLOAT)
          .item;
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
822
TEST_F(AutoMixedPrecisionTest, TensorListPushPop) {
  // Same clustering check as TensorListSetGet but using the PushBack/PopBack
  // list ops: tl1 (reached by allow ops) flips to DT_HALF, tl2 stays DT_FLOAT.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  auto tl1w1 =
      ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1w2 = ops::TensorListPushBack(s.WithOpName("tl1w2"),
                                       tl1w1.output_handle, allow1);
  Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"),
                                        tl1w2.output_handle, shape, DT_FLOAT)
                     .tensor;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  auto tl1w3 =
      ops::TensorListPushBack(s.WithOpName("tl1w3"), tl1.handle, allow2);
  Output tl1r2 = ops::TensorListPopBack(s.WithOpName("tl1r2"),
                                        tl1w3.output_handle, shape, DT_FLOAT)
                     .tensor;
  // tl2 is an independent list that never touches an allow op.
  auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, input);
  Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
                                        tl2w1.output_handle, shape, DT_FLOAT)
                     .tensor;
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
885
TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) {
  // Checks painting through TensorListFromTensor/TensorListStack: a list built
  // from an allow op's output is converted to DT_HALF along with its readers.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32};
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1 = ops::TensorListFromTensor(s.WithOpName("tl1"), allow1, shape);
  Output tl1r1 = ops::TensorListStack(s.WithOpName("tl1r1"), tl1.output_handle,
                                      shape, DT_FLOAT)
                     .tensor;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);

  // This tests that an allow-painted object node (tl2) will force an unpainted
  // client node (tl2w1) to be painted allow as well. (Without the force, tl2w1
  // would remain unpainted, producing an invalid graph.)
  auto tl2 = ops::TensorListFromTensor(s.WithOpName("tl2"), allow1, shape);
  auto tl2w1 =
      ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.output_handle, input);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 2e-4);
  }
}
935
TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
  // Exercises list handles flowing through TensorListPushBackBatch (via a
  // Stack of handles) and TensorListConcatLists: all three lists end up in the
  // same type cluster and are converted to DT_HALF.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  // Batch the two handles and the two payloads so a single PushBackBatch
  // writes allow1's output into both lists at once.
  Output tl1_tl2 =
      ops::Stack(s.WithOpName("tl1_tl2"), {tl1.handle, tl2.handle});
  Output allow1_allow1 =
      ops::Stack(s.WithOpName("allow1_allow1"), {allow1, allow1});
  auto tl12w1 = ops::TensorListPushBackBatch(s.WithOpName("tl12w1"), tl1_tl2,
                                             allow1_allow1);
  // Split the batched output handles back into two scalar variant handles.
  OutputList tl12w1_outputs =
      ops::Split(s.WithOpName("tl12w1_outputs"), 0, tl12w1.output_handles, 2)
          .output;
  Output scalar_shape = ops::Const(s.WithOpName("scalar_shape"), 0, {0});
  Output tl12w1_output0 = ops::Reshape(s.WithOpName("tl12w1_output0"),
                                       tl12w1_outputs[0], scalar_shape);
  Output tl12w1_output1 = ops::Reshape(s.WithOpName("tl12w1_output1"),
                                       tl12w1_outputs[1], scalar_shape);
  Output tl3 = ops::TensorListConcatLists(s.WithOpName("tl3"), tl12w1_output0,
                                          tl12w1_output1, DT_FLOAT);
  Output tl3r1 =
      ops::TensorListPopBack(s.WithOpName("tl3r1"), tl3, shape, DT_FLOAT)
          .tensor;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl3r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl3")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl3r1")->attr().at(type_key).type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
995
TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
  // This test passes a tensor list handle through a function with its own
  // Tensor List ops inside to test that the types are not changed to a
  // conflicting state.
  // A separate Tensor List cluster is added to test that it is still changed
  // to DT_HALF.
  FunctionDefLibrary function_lib;
  const Tensor kShape = test::AsTensor<int32>({32, 32});
  // Func1(ihandle, x): pushes x onto the list, pops it back, and returns the
  // resulting handle and popped value.
  FunctionDef func1 = FunctionDefHelper::Define(
      "Func1", {"ihandle: variant", "x: float"},
      {"ohandle: variant", "y: float"}, {},
      {
          {{"tl1w1_handle"},
           "TensorListPushBack",
           {"ihandle", "x"},
           {{"element_dtype", DT_FLOAT}}},
          {{"shape"}, "Const", {}, {{"value", kShape}, {"dtype", DT_INT32}}},
          {{"tl1r1_handle", "tl1r1_data"},
           "TensorListPopBack",
           {"tl1w1_handle", "shape"},
           {{"element_dtype", DT_FLOAT}}},
          {{"ohandle"}, "Identity", {"tl1r1_handle"}, {{"T", DT_VARIANT}}},
          {{"y"}, "Identity", {"tl1r1_data"}, {{"T", DT_FLOAT}}},
      });
  function_lib.add_function()->Swap(&func1);

  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib));
  tensorflow::Input shape = {32, 32};
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
  auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  auto tl1w1 =
      ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, infer1);
  // Build the Func1 call node manually since there is no generated op wrapper
  // for a user-defined function.
  auto _infer1 = tensorflow::ops::AsNodeOut(s, infer1);
  auto _tl1w1_handle = tensorflow::ops::AsNodeOut(s, tl1w1.output_handle);
  auto builder =
      tensorflow::NodeBuilder("Func1", "Func1", s.graph()->op_registry());
  tensorflow::Node* func1_op;
  TF_CHECK_OK(builder.Input(_tl1w1_handle)
                  .Input(_infer1)
                  .Finalize(s.graph(), &func1_op));
  Output func1_handle(func1_op, 0);
  Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"), func1_handle,
                                        shape, DT_FLOAT)
                     .tensor;
  // tl2 is a second list cluster that does not pass through the function.
  auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, infer1);
  Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
                                        tl2w1.output_handle, shape, DT_FLOAT)
                     .tensor;
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), tl1r1, tl2r1);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  const char* type_key = "element_dtype";
  // Note: no expectations on tl1/tl1w1/tl1r1 here; only the cluster that does
  // not cross the function boundary is required to flip to DT_HALF.
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
1079
GetCudaVersion(const Cluster & cluster)1080 int GetCudaVersion(const Cluster& cluster) {
1081 auto devices = cluster.GetDevices();
1082 for (const auto& device : devices) {
1083 const DeviceProperties& device_properties = device.second;
1084 if (device_properties.type() == "GPU") {
1085 const auto& device_env = device_properties.environment();
1086 auto it = device_env.find("cuda");
1087 if (it != device_env.end()) {
1088 string cuda_version_str = it->second;
1089 return std::stoi(cuda_version_str);
1090 }
1091 }
1092 }
1093 return 0;
1094 }
1095
// Returns true if the cluster's GPU supports the fp16 conversions exercised
// by these tests. In CUDA builds this requires CUDA >= 9010 (CUDA 9.1); in
// non-CUDA builds (e.g. ROCm) any GPU is considered supported.
// NOTE: uses `#if GOOGLE_CUDA` (matching the file guard at the top) rather
// than `#ifdef`, so a build that defines GOOGLE_CUDA to 0 does not wrongly
// take the CUDA branch.
bool IsSupportedGPU(const Cluster& cluster) {
#if GOOGLE_CUDA
  return GetCudaVersion(cluster) >= 9010;
#else
  return true;
#endif
}
1103
TEST_F(AutoMixedPrecisionTest, BatchMatMul) {
  // BatchMatMul is only converted to fp16 when the GPU supports it (see
  // IsSupportedGPU: CUDA >= 9010); otherwise the graph must be left untouched.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 33, {64, 32, 32});
  Output allow1 = ops::BatchMatMul(s.WithOpName("allow1"), input, input);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  if (IsSupportedGPU(*virtual_cluster_.get())) {
    // Supported GPU: two casts are inserted and the op runs in fp16.
    EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  } else {
    // Unsupported GPU: graph is unchanged.
    EXPECT_EQ(output.node_size(), item.graph.node_size());
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
  }

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 3.0e-3);
  }
}
1138
TEST_F(AutoMixedPrecisionTest, EluOp) {
  // Check that Elu survives the mixed-precision rewrite within tolerance.
  auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Elu(sc, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, 1.0e-3, build_op);
}
1146
TEST_F(AutoMixedPrecisionTest, ErfOp) {
  // Check that Erf survives the mixed-precision rewrite within tolerance.
  auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Erf(sc, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, build_op);
}
1154
TEST_F(AutoMixedPrecisionTest, ErfcOp) {
  // Check that Erfc survives the mixed-precision rewrite within tolerance.
  auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Erfc(sc, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, build_op);
}
1162
TEST_F(AutoMixedPrecisionTest, InvOp) {
  // Check that Inv survives the mixed-precision rewrite within tolerance.
  // Inputs are kept strictly positive to avoid the singularity at zero.
  auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Inv(sc, x);
  };
  TestSimpleUnaryInferOp(0.01, 10, -1, 1.0e-3, build_op);
}
1170
TEST_F(AutoMixedPrecisionTest, LogOp) {
  // Check that Log survives the mixed-precision rewrite within tolerance.
  // Inputs are kept strictly positive to stay inside Log's domain.
  auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Log(sc, x);
  };
  TestSimpleUnaryInferOp(0.01, 10, 1.0e-3, 2.0e-3, build_op);
}
1178
TEST_F(AutoMixedPrecisionTest, Log1pOp) {
  // Check that Log1p survives the mixed-precision rewrite within tolerance.
  // Inputs stay above -1 so that 1 + x remains in Log's domain.
  auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Log1p(sc, x);
  };
  TestSimpleUnaryInferOp(-0.99, 9, 1.0e-3, 5.0e-3, build_op);
}
1186
TEST_F(AutoMixedPrecisionTest, LogSoftmaxOp) {
  // Check that LogSoftmax survives the mixed-precision rewrite within
  // tolerance.
  auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::LogSoftmax(sc, x);
  };
  TestSimpleUnaryInferOp(-8, 8, -1, 1.0e-2, build_op);
}
1194
TEST_F(AutoMixedPrecisionTest, ReciprocalOp) {
  // Check that Reciprocal survives the mixed-precision rewrite within
  // tolerance. Inputs are kept strictly positive to avoid division by zero.
  auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Reciprocal(sc, x);
  };
  TestSimpleUnaryInferOp(0.01, 10, -1, 1.0e-3, build_op);
}
1202
TEST_F(AutoMixedPrecisionTest, SigmoidOp) {
  // Check that Sigmoid survives the mixed-precision rewrite within tolerance.
  auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Sigmoid(sc, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, build_op);
}
1210
TEST_F(AutoMixedPrecisionTest, SoftmaxOp) {
  // Check that Softmax survives the mixed-precision rewrite within tolerance.
  auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Softmax(sc, x);
  };
  TestSimpleUnaryInferOp(-8, 8, 2.0e-3, -1, build_op);
}
1218
TEST_F(AutoMixedPrecisionTest, SoftplusOp) {
  // Check that Softplus survives the mixed-precision rewrite within tolerance.
  auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Softplus(sc, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, 1.0e-3, build_op);
}
1226
TEST_F(AutoMixedPrecisionTest, SqrtOp) {
  // Check that Sqrt survives the mixed-precision rewrite within tolerance.
  // Inputs are non-negative to stay inside Sqrt's domain.
  auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Sqrt(sc, x);
  };
  TestSimpleUnaryInferOp(0, 10, 1.0e-3, 1.0e-3, build_op);
}
1234
TEST_F(AutoMixedPrecisionTest, TanhOp) {
  // Check that Tanh survives the mixed-precision rewrite within tolerance.
  auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Tanh(sc, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, build_op);
}
1242
// Fixture for tests that run AutoMixedPrecision in CPU mode
// (AutoMixedPrecisionMode::CPU) against a real SingleMachine cluster with no
// GPUs.
class AutoMixedPrecisionCpuTest : public GrapplerTest {
 protected:
  void SetUp() override {
    // Args after the timeout are presumably num_cpu_cores=1, num_gpus=0 —
    // TODO(review): confirm against the SingleMachine constructor.
    virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0));
    TF_CHECK_OK(virtual_cluster_->Provision());
  }
  void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }

  std::unique_ptr<Cluster> virtual_cluster_;
};
1253
TEST_F(AutoMixedPrecisionCpuTest, Simple) {
  // In CPU mode the MatMul itself stays FP32 but the surrounding region is
  // converted: every input edge of allow1 goes through a half->float Cast and
  // every output edge through a float->half Cast (9 Casts in total).
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
  Output infer2 = ops::Log(s.WithOpName("infer2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), infer2);
  Output deny2 = ops::SparseMatMul(s.WithOpName("deny2"), clr4, clr4);
  Output clr5 = ops::Relu(s.WithOpName("clr5"), deny2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  // Run the optimizer in CPU mode.
  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::CPU};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  const int expected_cast_ops = 9;
  EXPECT_EQ(output.node_size(), item.graph.node_size() + expected_cast_ops);

  GraphView output_view(&output);
  // Matmul is a FP32 op now
  auto matmul_op = output_view.GetNode("allow1");
  EXPECT_EQ(matmul_op->attr().at("T").type(), DT_FLOAT);
  // Each data input of the MatMul comes from a half->float Cast...
  for (auto edge : output_view.GetFaninEdges(*matmul_op, false)) {
    EXPECT_EQ(edge.src.node->op(), "Cast");
    EXPECT_EQ(edge.src.node->attr().at("SrcT").type(), DT_HALF);
    EXPECT_EQ(edge.src.node->attr().at("DstT").type(), DT_FLOAT);
  }
  // ...and each data output feeds a float->half Cast.
  for (auto edge : output_view.GetFanoutEdges(*matmul_op, false)) {
    EXPECT_EQ(edge.dst.node->op(), "Cast");
    EXPECT_EQ(edge.dst.node->attr().at("SrcT").type(), DT_FLOAT);
    EXPECT_EQ(edge.dst.node->attr().at("DstT").type(), DT_HALF);
  }
}
1299
TEST_F(AutoMixedPrecisionCpuTest, MixedFanout) {
  // Test when an FP16 allowed node has a mixed fanout of FP16 allowed node and
  // FP32 node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output input1 = ops::Const(s.WithOpName("input1"), 1.f / 32, {32, 32});
  Output input2 = ops::Const(s.WithOpName("input2"), 2.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input1, input2);
  // allow1 fans out to both allow2 (fp16-friendly) and deny (must stay FP32).
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), allow1, input2);
  Output deny = ops::Exp(s.WithOpName("deny"), allow1);
  Output infer = ops::Add(s.WithOpName("infer"), deny, allow2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), infer);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  // Run the optimizer in CPU mode.
  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::CPU};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  const int expected_cast_ops = 10;
  EXPECT_EQ(output.node_size(), item.graph.node_size() + expected_cast_ops);

  GraphView output_view(&output);
  // allow1 is fenced by Casts in both directions (see CpuTest.Simple).
  auto allow1_op = output_view.GetNode("allow1");
  for (auto edge : output_view.GetFaninEdges(*allow1_op, false)) {
    EXPECT_EQ(edge.src.node->op(), "Cast");
    EXPECT_EQ(edge.src.node->attr().at("SrcT").type(), DT_HALF);
    EXPECT_EQ(edge.src.node->attr().at("DstT").type(), DT_FLOAT);
  }
  for (auto edge : output_view.GetFanoutEdges(*allow1_op, false)) {
    EXPECT_EQ(edge.dst.node->op(), "Cast");
    EXPECT_EQ(edge.dst.node->attr().at("SrcT").type(), DT_FLOAT);
    EXPECT_EQ(edge.dst.node->attr().at("DstT").type(), DT_HALF);
  }
  // The deny node reads through a half->float Cast but its own output is
  // consumed directly (no Cast back to half).
  auto deny_op = output_view.GetNode("deny");
  for (auto edge : output_view.GetFaninEdges(*deny_op, false)) {
    EXPECT_EQ(edge.src.node->op(), "Cast");
    EXPECT_EQ(edge.src.node->attr().at("SrcT").type(), DT_HALF);
    EXPECT_EQ(edge.src.node->attr().at("DstT").type(), DT_FLOAT);
  }
  for (auto edge : output_view.GetFanoutEdges(*deny_op, false)) {
    EXPECT_NE(edge.dst.node->op(), "Cast");
  }
}
1349
1350 class AutoMixedPrecisionSimulateGpuTest : public GrapplerTest {
1351 protected:
SetUp()1352 void SetUp() override {
1353 std::unordered_map<string, DeviceProperties> devices;
1354 DeviceProperties cpu_device;
1355 cpu_device.set_type("CPU");
1356 cpu_device.set_frequency(1000);
1357 cpu_device.set_num_cores(4);
1358 cpu_device.set_memory_size(1024 * 1024);
1359 devices["/job:localhost/replica:0/task:0/device:CPU:0"] = cpu_device;
1360 // Explicitly creating machine without GPU.
1361 virtual_cluster_.reset(new VirtualCluster(devices));
1362 TF_CHECK_OK(virtual_cluster_->Provision());
1363 }
TearDown()1364 void TearDown() override {
1365 unsetenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU");
1366 TF_CHECK_OK(virtual_cluster_->Shutdown());
1367 }
1368
1369 std::unique_ptr<Cluster> virtual_cluster_;
1370
TestSimple(tensorflow::Scope s,bool is_optimized)1371 void TestSimple(tensorflow::Scope s, bool is_optimized) {
1372 Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
1373 Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
1374 Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
1375 Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
1376 Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
1377 Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
1378 Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
1379 Output infer2 = ops::Log(s.WithOpName("infer2"), clr3);
1380 Output clr4 = ops::Relu(s.WithOpName("clr4"), infer2);
1381 Output deny2 = ops::SparseMatMul(s.WithOpName("deny2"), clr4, clr4);
1382 Output clr5 = ops::Relu(s.WithOpName("clr5"), deny2);
1383 Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);
1384
1385 GrapplerItem item;
1386 item.fetch = {"fetch"};
1387 TF_CHECK_OK(s.ToGraphDef(&item.graph));
1388 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
1389
1390 GraphDef output;
1391 AutoMixedPrecision optimizer;
1392 TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
1393
1394 VLOG(1) << output.DebugString();
1395
1396 GraphView output_view(&output);
1397 DataType expected_data_type = is_optimized ? DT_HALF : DT_FLOAT;
1398 int expected_graph_size =
1399 is_optimized ? item.graph.node_size() + 2 : item.graph.node_size();
1400
1401 EXPECT_EQ(output.node_size(), expected_graph_size);
1402 EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(),
1403 DT_FLOAT);
1404 EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
1405 EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
1406 EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
1407 EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(),
1408 expected_data_type);
1409 EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(),
1410 expected_data_type);
1411 EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(),
1412 expected_data_type);
1413 EXPECT_EQ(output_view.GetNode("infer2")->attr().at("T").type(), DT_FLOAT);
1414 EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
1415 EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Ta").type(), DT_FLOAT);
1416 EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Tb").type(), DT_FLOAT);
1417 EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);
1418
1419 auto tensors = EvaluateNodes(output, item.fetch);
1420 EXPECT_EQ(tensors.size(), tensors_expected.size());
1421 EXPECT_EQ(tensors.size(), item.fetch.size());
1422 for (int i = 0; i < item.fetch.size(); ++i) {
1423 test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
1424 }
1425 }
1426 };
1427
TEST_F(AutoMixedPrecisionSimulateGpuTest, Simple_NoGpu) {
  // No GPU present and simulation not enabled: the graph must stay in fp32.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  TestSimple(root, /* is_optimized= */ false);
}
1431
TEST_F(AutoMixedPrecisionSimulateGpuTest, Simple_SimulatedGpu) {
  // With GPU simulation turned on, the rewrite should fire even though the
  // cluster only has a CPU.
  setenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU", "true",
         1 /* replace */);
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  TestSimple(root, /* is_optimized= */ true);
}
1437
TEST_F(AutoMixedPrecisionSimulateGpuTest, Simple_SimulatedGpu_CpuScope) {
  // Even with simulation enabled, nodes explicitly placed on a CPU device
  // must not be rewritten.
  setenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU", "true",
         1 /* replace */);
  tensorflow::Scope cpu_scope = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  TestSimple(cpu_scope, /* is_optimized= */ false);
}
1445
1446 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1447
1448 #if INTEL_MKL
1449
1450 class AutoMixedPrecisionMklTest : public GrapplerTest {
1451 protected:
SetUp()1452 void SetUp() override {
1453 virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0));
1454 TF_CHECK_OK(virtual_cluster_->Provision());
1455 }
TearDown()1456 void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }
1457
1458 std::unique_ptr<Cluster> virtual_cluster_;
1459 };
1460
TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) {
  // A graph that already contains explicit bfloat16 casts should come out of
  // the optimizer structurally equivalent, with every dtype preserved.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output graph_in = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
  Output to_bf16 = ops::Cast(s.WithOpName("cst1"), graph_in, DT_BFLOAT16);
  Output matmul = ops::MatMul(s.WithOpName("allow1"), to_bf16, to_bf16);
  Output relu_a = ops::Relu(s.WithOpName("clr1"), matmul);
  Output to_fp32 = ops::Cast(s.WithOpName("cst2"), relu_a, DT_FLOAT);
  Output relu_b = ops::Relu(s.WithOpName("clr2"), to_fp32);
  Output sink = ops::Identity(s.WithOpName("fetch"), relu_b);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::BF16};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  VLOG(1) << output.DebugString();

  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
  GraphView output_view(&output);
  // All dtypes must be exactly what the original graph specified.
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
1499
TEST_F(AutoMixedPrecisionMklTest, Simple) {
  // Chain of deny / clear / infer / allow ops: only the section between the
  // MatMul and its adjacent clear ops should be converted to bfloat16.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output src = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output exp_out = ops::Exp(s.WithOpName("deny1"), src);
  Output relu1 = ops::Relu(s.WithOpName("clr1"), exp_out);
  Output sqrt_out = ops::Sqrt(s.WithOpName("infer1"), relu1);
  Output relu2 = ops::Relu(s.WithOpName("clr2"), sqrt_out);
  Output matmul_out = ops::MatMul(s.WithOpName("allow1"), relu2, relu2);
  Output relu3 = ops::Relu(s.WithOpName("clr3"), matmul_out);
  Output log_out = ops::Log(s.WithOpName("deny2"), relu3);
  Output relu4 = ops::Relu(s.WithOpName("clr4"), log_out);
  Output sparse_out = ops::SparseMatMul(s.WithOpName("deny3"), relu4, relu4);
  Output relu5 = ops::Relu(s.WithOpName("clr5"), sparse_out);
  Output sink = ops::Identity(s.WithOpName("fetch"), relu5);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::BF16};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Exactly two Cast nodes are added around the bf16 section.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Ta").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Tb").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
1549
TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
  // Checks that TensorList ops are recolored as a unit: list tl1, whose items
  // are produced/consumed by allow ops (MatMul), should have element_dtype
  // rewritten to bfloat16, while list tl2, only written from a Const and read
  // back, must stay float.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output idx1 = ops::Const(s.WithOpName("idx1"), 1);
  Output idx2 = ops::Const(s.WithOpName("idx2"), 2);
  Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
  // tl1 is written both from a plain Const and from an allow-op (MatMul)
  // output, which ties the whole list to the bf16 color.
  auto tl1w1 =
      ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1w2 =
      ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
  // Ensure that TensorListResize doesn't cause any problems.
  Output tl1rs =
      ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
  Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
                                        shape, DT_FLOAT)
                     .item;
  Output infer1 = ops::Mul(s.WithOpName("infer1"), tl1r1, tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  auto tl1w3 =
      ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
  Output tl1r2 =
      ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
                             shape, DT_FLOAT)
          .item;
  // tl2 never touches an allow op, so it should remain float after the
  // rewrite.
  auto tl2 = ops::TensorListReserve(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListSetItem(s.WithOpName("tl2w1"), tl2.handle, idx1, input);
  Output tl2r1 =
      ops::TensorListGetItem(s.WithOpName("tl2r1"), tl2w1.output_handle, idx1,
                             shape, DT_FLOAT)
          .item;
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::BF16};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Two Cast nodes are expected to be inserted by the rewrite.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-2);
  }
}
1626
TEST_F(AutoMixedPrecisionMklTest, InferFollowUpStreamAllow) {
  if (!IsMKLEnabled())
    GTEST_SKIP() << "Test only applicable to MKL auto-mixed precision.";
  // An infer op (BiasAdd) fed directly by an allow op (Conv2D) should follow
  // the allow op into bfloat16, along with the trailing clear op (Relu).
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output conv_in =
      ops::Const(s.WithOpName("input1"), 1.f / 32, {8, 56, 56, 16});
  Output filter = ops::Const(s.WithOpName("weight"), 2.f, {3, 3, 16, 16});
  Output conv =
      ops::Conv2D(s.WithOpName("allow"), conv_in, filter, {1, 1, 1, 1}, "SAME",
                  ops::Conv2D::DataFormat("NHWC"));
  Output bias = ops::Const(s.WithOpName("input2"), 1.f / 32, {16});
  Output biasadd = ops::BiasAdd(s.WithOpName("infer"), conv, bias);
  Output relu = ops::Relu(s.WithOpName("clr"), biasadd);
  Output sink = ops::Identity(s.WithOpName("fetch"), relu);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::BF16};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 4);
  EXPECT_EQ(output_view.GetNode("input1")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("weight")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("input2")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("infer")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("clr")->attr().at("T").type(), DT_BFLOAT16);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-2);
  }
}
1669
TEST_F(AutoMixedPrecisionMklTest, InferFollowUpStreamDeny) {
  if (!IsMKLEnabled())
    GTEST_SKIP() << "Test only applicable to MKL auto-mixed precision.";
  // An infer op (BiasAdd) fed by a deny op (Pow) must stay in fp32; nothing
  // downstream should be converted and no Cast nodes should be added.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output base = ops::Const(s.WithOpName("input1"), 1.f / 32, {8, 56, 56, 16});
  Output exponent = ops::Const(s.WithOpName("input2"), 1.f, {16});
  Output bias = ops::Const(s.WithOpName("input3"), 1.f / 32, {16});
  Output pow_out = ops::Pow(s.WithOpName("deny"), base, exponent);
  Output biasadd = ops::BiasAdd(s.WithOpName("infer"), pow_out, bias);
  Output relu = ops::Relu(s.WithOpName("clr"), biasadd);
  Output sink = ops::Identity(s.WithOpName("fetch"), relu);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::BF16};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Graph size unchanged: no Casts were inserted.
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  EXPECT_EQ(output_view.GetNode("input1")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("input2")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("input3")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i]);
  }
}
1710 #endif // INTEL_MKL
1711
1712 } // namespace
1713 } // namespace grappler
1714 } // namespace tensorflow
1715
1716 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || INTEL_MKL
1717