1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || \
17 (INTEL_MKL && defined(ENABLE_INTEL_MKL_BFLOAT16))
18
19 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
20
21 #include <utility>
22 #include <vector>
23
24 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
25 #include "tensorflow/cc/ops/list_ops.h"
26 #include "tensorflow/cc/ops/math_ops.h"
27 #include "tensorflow/cc/ops/standard_ops.h"
28 #include "tensorflow/core/framework/function_testlib.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/tensor_testutil.h"
31 #include "tensorflow/core/grappler/clusters/single_machine.h"
32 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
33 #include "tensorflow/core/grappler/devices.h"
34 #include "tensorflow/core/grappler/graph_view.h"
35 #include "tensorflow/core/grappler/utils/grappler_test.h"
36 #include "tensorflow/core/lib/core/status_test_util.h"
37 #include "tensorflow/core/lib/random/random.h"
38
39 // TODO(benbarsdell): Improve the numerical checks in these tests. The tests
40 // were originally written only to check the graph coloring, so the graphs do
41 // not have particularly realistic numerical behavior.
42
43 namespace tensorflow {
44 namespace grappler {
45 namespace {
46
47 template <DataType DTYPE>
GenerateIdentityMatrix(int64 height,int64 width)48 Tensor GenerateIdentityMatrix(int64 height, int64 width) {
49 typedef typename EnumToDataType<DTYPE>::Type T;
50 Tensor tensor(DTYPE, TensorShape{height, width});
51 for (int64 i = 0; i < height; ++i) {
52 for (int64 j = 0; j < width; ++j) {
53 tensor.matrix<T>()(i, j) = i == j;
54 }
55 }
56 return tensor;
57 }
58
59 template <DataType DTYPE>
GenerateRandomTensorInRange(const TensorShape & shape,double minval,double maxval)60 Tensor GenerateRandomTensorInRange(const TensorShape& shape, double minval,
61 double maxval) {
62 typedef typename EnumToDataType<DTYPE>::Type T;
63 Tensor tensor(DTYPE, shape);
64 for (auto i = 0; i < tensor.NumElements(); i++)
65 tensor.flat<T>()(i) =
66 (random::New64() % 65536 / 65536.0) * (maxval - minval) + minval;
67 return tensor;
68 }
69
VerifyGraphsEquivalent(const GraphDef & original_graph,const GraphDef & optimized_graph,const string & func)70 void VerifyGraphsEquivalent(const GraphDef& original_graph,
71 const GraphDef& optimized_graph,
72 const string& func) {
73 EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
74 GraphView optimized_view(&optimized_graph);
75 for (int i = 0; i < original_graph.node_size(); ++i) {
76 const NodeDef& original = original_graph.node(i);
77 const NodeDef& optimized = *optimized_view.GetNode(original.name());
78 EXPECT_EQ(original.name(), optimized.name()) << func;
79 EXPECT_EQ(original.op(), optimized.op()) << func;
80 EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
81 if (original.input_size() == optimized.input_size()) {
82 for (int j = 0; j < original.input_size(); ++j) {
83 EXPECT_EQ(original.input(j), optimized.input(j)) << func;
84 }
85 }
86 }
87 }
88
89 // Currently, this test suite only passes when TensorFlow passes with CUDA,
90 // because otherwise the optimizer will not turn clearlist nodes to float16.
91 // When looking at clearlist nodes, this optimizer checks if the nodes have a
92 // float16 GPU OpKernel, but without CUDA there are no GPU OpKernels at all.
93 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
94
95 const std::pair<int, int> kMinGPUArch = {7, 0};
96
// Test fixture for the AutoMixedPrecision grappler pass. SetUp() provisions
// either a real single-machine cluster (when suitable physical GPUs exist) or
// a VirtualCluster advertising one GPU device, so the optimizer always sees a
// GPU target.
class AutoMixedPrecisionTest : public GrapplerTest {
 protected:
  void SetUp() override {
    int num_gpus = GetNumAvailableGPUs();
    // If GPUs are available, require that they all satisfy the min arch.
    gpu_available_ = (num_gpus > 0);
#if GOOGLE_CUDA
    gpu_available_ =
        gpu_available_ && (num_gpus == GetNumAvailableGPUs(kMinGPUArch));
#endif
    if (gpu_available_) {
      virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 1));
    } else {
      // No usable physical GPU: build a virtual GPU device. Under CUDA it
      // reports architecture 7 and CUDA 9010 so the pass still activates.
      DeviceProperties device_properties;
      device_properties.set_type("GPU");
#if GOOGLE_CUDA
      device_properties.mutable_environment()->insert({"architecture", "7"});
      device_properties.mutable_environment()->insert({"cuda", "9010"});
#endif
      virtual_cluster_.reset(
          new VirtualCluster({{"/GPU:1", device_properties}}));
    }
    TF_CHECK_OK(virtual_cluster_->Provision());
  }

  void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }

  // Adds a node named `name` with op `op` and the given `inputs` to `graph`,
  // filling in each op's required attributes with float32 defaults. Different
  // ops spell their dtype attribute(s) differently, hence the dispatch below.
  NodeDef* AddSimpleNode(const string& name, const string& op,
                         const std::vector<string>& inputs,
                         GraphDef* graph) const {
    std::vector<std::pair<string, AttrValue>> attributes;
    if (op == "AddN" || op == "ShapeN") {
      AttrValue num_inputs;
      num_inputs.set_i(inputs.size());
      attributes.emplace_back("N", num_inputs);
    }
    if (op == "ShapeN") {
      AttrValue out_type;
      out_type.set_type(DT_INT32);
      attributes.emplace_back("out_type", out_type);
    }
    AttrValue type;
    type.set_type(DT_FLOAT);
    if (op == "Const" || op == "Placeholder" || op == "VariableV2" ||
        op == "VarHandleOp" || op == "ReadVariableOp") {
      attributes.emplace_back("dtype", type);
    } else if (op == "SparseMatMul") {
      attributes.emplace_back("Ta", type);
      attributes.emplace_back("Tb", type);
    } else if (op == "IdentityN") {
      // IdentityN takes a list-typed attr with one entry per input.
      AttrValue type_list;
      for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
        type_list.mutable_list()->add_type(DT_FLOAT);
      }
      attributes.emplace_back("T", type_list);
    } else if (op == "StackV2" || op == "StackPopV2") {
      attributes.emplace_back("elem_type", type);
    } else if (op == "Cast") {
      attributes.emplace_back("SrcT", type);
      attributes.emplace_back("DstT", type);
    } else {
      attributes.emplace_back("T", type);
    }
    return AddNode(name, op, inputs, attributes, graph);
  }

  // Builds MatMul -> <unary op> -> MatMul, runs the optimizer, and checks
  // that all three nodes were painted half precision. The unary op is fed
  // random values in [input_min, input_max), and the optimized graph's output
  // must match the float32 graph within atol/rtol.
  void TestSimpleUnaryInferOp(
      double input_min, double input_max, double atol, double rtol,
      const std::function<Output(const tensorflow::Scope&, Output)>&
          test_op_factory) {
    int size = 128;
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    // Multiplying by the identity keeps values numerically unchanged while
    // still inserting MatMul (allowlist) nodes around the op under test.
    Output eye = ops::Const(s.WithOpName("eye"),
                            GenerateIdentityMatrix<DT_FLOAT>(size, size));
    Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
    Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, eye);
    Output infer1 = test_op_factory(s.WithOpName("infer1"), allow1);
    Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, eye);
    Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
    GrapplerItem item;
    item.fetch = {"fetch1"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    auto input_tensor = GenerateRandomTensorInRange<DT_FLOAT>(
        TensorShape({size, size}), input_min, input_max);
    std::vector<std::pair<string, Tensor>> feed = {{"input", input_tensor}};
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);

    AutoMixedPrecision optimizer;
    GraphDef output;
    TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

    VLOG(1) << output.DebugString();

    GraphView output_view(&output);
    EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(),
              DT_FLOAT);
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
    EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
    EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

    auto tensors = EvaluateNodes(output, item.fetch, feed);
    EXPECT_EQ(tensors.size(), tensors_expected.size());
    EXPECT_EQ(tensors.size(), item.fetch.size());
    for (int i = 0; i < item.fetch.size(); ++i) {
      test::ExpectClose(tensors_expected[i], tensors[i], atol, rtol);
    }
  }

  // Cluster handed to the optimizer; real or virtual, see SetUp().
  std::unique_ptr<Cluster> virtual_cluster_;
  // True iff physical GPUs are present (and, under CUDA, all meet
  // kMinGPUArch).
  bool gpu_available_;
};
208
// With no allowlist ops in the graph, the optimizer must leave every node in
// float32 and change nothing (graphs compare equivalent node-for-node).
TEST_F(AutoMixedPrecisionTest, NoOp) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.234f, {32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);

  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t avoids the signed/unsigned comparison the original `int i` caused.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
245
// A graph whose MatMul section is already half precision (explicit Casts to
// and from DT_HALF) must pass through with dtypes unchanged.
TEST_F(AutoMixedPrecisionTest, AlreadyFp16) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_HALF);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
  Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  VLOG(1) << output.DebugString();

  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t avoids the signed/unsigned comparison the original `int i` caused.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
283
// Basic coloring test: only the allowlist MatMul and the clear nodes directly
// adjacent to it become half; deny nodes (Exp, Log, SparseMatMul) and their
// neighbors stay float. Two Cast nodes are expected to be inserted.
TEST_F(AutoMixedPrecisionTest, Simple) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
  Output infer2 = ops::Log(s.WithOpName("infer2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), infer2);
  Output deny2 = ops::SparseMatMul(s.WithOpName("deny2"), clr4, clr4);
  Output clr5 = ops::Relu(s.WithOpName("clr5"), deny2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // +2: one Cast into and one Cast out of the half-precision region.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Ta").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Tb").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t avoids the signed/unsigned comparison the original `int i` caused.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
332
// Half precision must propagate through clear nodes in both directions:
// clr2/clr4 have no direct path to the MatMul except through the multi-input
// ShapeN node, yet all clear nodes are expected to turn half.
TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
  auto clr3 = ops::ShapeN(s.WithOpName("clr3"), {clr1, clr2});
  Output clr4 = ops::Relu(s.WithOpName("clr4"), clr2);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), clr4);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Three Cast nodes are expected to be inserted around the half region.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 3);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t avoids the signed/unsigned comparison the original `int i` caused.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
371
// Nodes listed in item.fetch must keep their float32 output type even when
// they would otherwise qualify for conversion (allow1, clr2, clr3 here).
// Only allow2, which is not fetched, may turn half.
TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output deny1 = ops::Exp(s.WithOpName("deny1"), infer1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), deny1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow2);
  Output deny2 = ops::Exp(s.WithOpName("deny2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), deny2);

  GrapplerItem item;
  item.fetch = {"allow1", "clr2", "clr3"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Two Cast nodes are expected around the lone half-precision MatMul.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t avoids the signed/unsigned comparison the original `int i` caused.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-3);
  }
}
416
// Nodes explicitly placed on a CPU device must not be converted: allow2 is
// pinned to CPU and stays float, which also keeps its neighbor infer1 float.
TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
  Output allow2 =
      ops::MatMul(s.WithOpName("allow2").WithDevice(
                      "/job:localhost/replica:0/task:0/device:CPU:0"),
                  infer1, infer1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), allow2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Two Cast nodes are expected around the GPU half-precision region.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t avoids the signed/unsigned comparison the original `int i` caused.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
457
// An Identity that directly follows a Variable (clr1) must stay float so the
// variable itself is not converted; an Identity after a Const (clr2) has no
// such restriction and may turn half.
TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output var1 = ops::Variable(s.WithOpName("var1"), {32, 32}, DT_FLOAT);
  Output clr1 = ops::Identity(s.WithOpName("clr1"), var1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, clr1);
  Output input2 = ops::Const(s.WithOpName("input2"), 1.f / 32, {32, 32});
  Output clr2 = ops::Identity(s.WithOpName("clr2"), input2);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), input, clr2);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // The variable is fed directly since nothing in the graph initializes it.
  auto var1_tensor =
      GenerateConstantTensor<DT_FLOAT>(TensorShape({32, 32}), 3.141593f);
  std::vector<std::pair<string, Tensor>> feed = {{"var1", var1_tensor}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Five Cast nodes are expected to be inserted.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 5);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("var1")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("input2")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t avoids the signed/unsigned comparison the original `int i` caused.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-3);
  }
}
501
// Checks that FusedBatchNorm/FusedBatchNormGrad between half-precision Conv2D
// nodes are rewritten to their V2 variants with a half data type T and a
// float statistics type U.
TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // Uses NHWC data format because non-GPU execution does not support NCHW.
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {8, 56, 56, 16});
  Output weight = ops::Const(s.WithOpName("weight"), 2.f, {3, 3, 16, 16});
  Output scale = ops::Const(s.WithOpName("scale"), 3.f, {16});
  Output offset = ops::Const(s.WithOpName("offset"), 4.f, {16});
  // Empty mean/variance ({0}) selects training-mode batch norm, which
  // computes the statistics itself.
  Output mean = ops::Const(s.WithOpName("mean"), 5.f, {0});
  Output variance = ops::Const(s.WithOpName("variance"), 6.f, {0});
  Output allow1 =
      ops::Conv2D(s.WithOpName("allow1"), input, weight, {1, 1, 1, 1}, "SAME",
                  ops::Conv2D::DataFormat("NHWC"));
  auto fbn1_op =
      ops::FusedBatchNorm(s.WithOpName("fbn1"), allow1, scale, offset, mean,
                          variance, ops::FusedBatchNorm::DataFormat("NHWC"));
  Output fbn1 = fbn1_op.y;
  Output fbn1_rs1 = fbn1_op.reserve_space_1;
  Output fbn1_rs2 = fbn1_op.reserve_space_2;
  // The grad op consumes the forward op's reserve spaces, mirroring how the
  // pair appears in real training graphs.
  Output bng1 = ops::FusedBatchNormGrad(
                    s.WithOpName("bng1"), fbn1, allow1, scale, fbn1_rs1,
                    fbn1_rs2, ops::FusedBatchNormGrad::DataFormat("NHWC"))
                    .x_backprop;
  Output infer1 = ops::Add(s.WithOpName("infer1"), fbn1, bng1);
  Output allow2 =
      ops::Conv2D(s.WithOpName("allow2"), infer1, weight, {1, 1, 1, 1}, "SAME",
                  ops::Conv2D::DataFormat("NHWC"));
  Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Three Cast nodes are expected to be inserted.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 3);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  // Both batch-norm ops must be upgraded to V2 so T (data) can be half while
  // U (statistics) remains float.
  EXPECT_EQ(output_view.GetNode("fbn1")->op(), "FusedBatchNormV2");
  EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("U").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("bng1")->op(), "FusedBatchNormGradV2");
  EXPECT_EQ(output_view.GetNode("bng1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("bng1")->attr().at("U").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-3);
  }
}
560
// Ops with repeated (AddN's "N") and list-typed (IdentityN's "T" list)
// attributes must have every entry converted to half when they sit between
// two allowlist MatMuls.
TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto clr1_op = ops::IdentityN(s.WithOpName("clr1"), {allow1, allow1, allow1});
  Output infer1 =
      ops::AddN(s.WithOpName("infer1"),
                {clr1_op.output[0], clr1_op.output[1], clr1_op.output[2]});
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Two Cast nodes are expected around the half region.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  // Every element of IdentityN's type list must have been converted.
  for (auto type : output_view.GetNode("clr1")->attr().at("T").list().type()) {
    EXPECT_EQ(type, DT_HALF);
  }
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t avoids the signed/unsigned comparison the original `int i` caused.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
600
// A pre-existing bool->float Cast feeding an allowlist MatMul should be
// retargeted to produce half directly (DstT becomes DT_HALF) instead of
// inserting an extra float->half Cast.
TEST_F(AutoMixedPrecisionTest, ExistingCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), true, {32, 32});
  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_FLOAT);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
  Output fetch = ops::Identity(s.WithOpName("fetch"), allow1);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Only one Cast (half back to float for the fetch) should be added.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 1);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("SrcT").type(), DT_BOOL);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t avoids the signed/unsigned comparison the original `int i` caused.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
632
// Builds a while-loop whose back edge (NextIteration -> Merge) connects nodes
// of different paint colors, and checks that the whole loop body stays float.
// The loop cannot be expressed directly with the Scope API, so the graph is
// first built with a placeholder edge and then rewired via NodeMap below.
TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output ent1 =
      ops::internal::Enter(s.WithOpName("ent1"), deny1, "loop1").output;
  // Note that the second input is later replaced with "nxt1".
  Output mrg1 = ops::Merge(s.WithOpName("mrg1"), {ent1, ent1}).output;
  // For simplicity, the loop condition is constant false.
  Output con1 = ops::Const(s.WithOpName("con1"), false, {});
  Output lpc1 = ops::LoopCond(s.WithOpName("lpc1"), con1).output;
  auto swt1 = ops::Switch(s.WithOpName("swt1"), mrg1, lpc1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), swt1.output_true);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), infer1, infer1);
  Output nxt1 = ops::NextIteration(s.WithOpName("nxt1"), allow1);
  Output ext1 = ops::internal::Exit(s.WithOpName("ext1"), swt1.output_false);
  Output fetch = ops::Identity(s.WithOpName("fetch"), ext1);
  // Add a second merge node from the same NextIteration node. This case arises
  // during graph optimization of some models.
  auto mrg2 = ops::Merge(s.WithOpName("mrg2"), {ent1, nxt1});

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  NodeMap node_map_original(&item.graph);
  auto merge_node = node_map_original.GetNode("mrg1");
  // Modify the graph to create a loop.
  merge_node->set_input(1, "nxt1");
  // Add a control edge to ensure the loop condition is inside the frame.
  auto const_node = node_map_original.GetNode("con1");
  const_node->add_input("^mrg1");
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Two Cast nodes are expected to be inserted around allow1.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  // Note that mrg1 gets painted deny because it is between deny1 and infer1.
  // This forces nxt1 and mrg2 to be painted deny as well (they would otherwise
  // be painted allow because they are clear and have a direct path to allow1).
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("ent1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("mrg1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("swt1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("nxt1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("ext1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("mrg2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
694
// Tests painting of TensorList handles threaded through
// TensorListReserve/SetItem/GetItem (and TensorListResize): list tl1, whose
// elements are written from and read into allow (MatMul) ops, should have its
// element_dtype flipped to DT_HALF, while tl2, which only stores the Const
// input, must remain DT_FLOAT.
TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output idx1 = ops::Const(s.WithOpName("idx1"), 1);
  Output idx2 = ops::Const(s.WithOpName("idx2"), 2);
  Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
  auto tl1w1 =
      ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1w2 =
      ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
  // Ensure that TensorListResize doesn't cause any problems.
  Output tl1rs =
      ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
  Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
                                        shape, DT_FLOAT)
                     .item;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  auto tl1w3 =
      ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
  Output tl1r2 =
      ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
                             shape, DT_FLOAT)
          .item;
  // tl2 is an identical list structure that never touches an allow op, so it
  // should be left in float32.
  auto tl2 = ops::TensorListReserve(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListSetItem(s.WithOpName("tl2w1"), tl2.handle, idx1, input);
  Output tl2r1 =
      ops::TensorListGetItem(s.WithOpName("tl2r1"), tl2w1.output_handle, idx1,
                             shape, DT_FLOAT)
          .item;
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Exactly two nodes are added (presumably a Cast pair at the f32/f16
  // boundary — confirm against the optimizer's cast insertion).
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);

  // Numerical check: fp16 compute, hence the loose relative tolerance.
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
765
// Same scenario as TensorListSetGet but exercising the stack-style ops
// (EmptyTensorList / TensorListPushBack / TensorListPopBack): tl1 feeds and is
// fed by MatMul (allow) ops and should become DT_HALF; tl2 only stores the
// Const input and must stay DT_FLOAT.
TEST_F(AutoMixedPrecisionTest, TensorListPushPop) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  auto tl1w1 =
      ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1w2 = ops::TensorListPushBack(s.WithOpName("tl1w2"),
                                       tl1w1.output_handle, allow1);
  Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"),
                                        tl1w2.output_handle, shape, DT_FLOAT)
                     .tensor;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  auto tl1w3 =
      ops::TensorListPushBack(s.WithOpName("tl1w3"), tl1.handle, allow2);
  Output tl1r2 = ops::TensorListPopBack(s.WithOpName("tl1r2"),
                                        tl1w3.output_handle, shape, DT_FLOAT)
                     .tensor;
  // Control list: never touches an allow op, so it must stay float32.
  auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, input);
  Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
                                        tl2w1.output_handle, shape, DT_FLOAT)
                     .tensor;
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
828
// Tests painting through TensorListFromTensor/TensorListStack: list tl1 is
// produced from an allow (MatMul) output and read back into another MatMul,
// so its element_dtype should become DT_HALF.
TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32};
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1 = ops::TensorListFromTensor(s.WithOpName("tl1"), allow1, shape);
  Output tl1r1 = ops::TensorListStack(s.WithOpName("tl1r1"), tl1.output_handle,
                                      shape, DT_FLOAT)
                     .tensor;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);

  // This tests that a allow-painted object node (tl2) will force an unpainted
  // client node (tl2w1) to be painted allow as well. (Without the force, tl2w1
  // would remain unpainted, producing an invalid graph).
  auto tl2 = ops::TensorListFromTensor(s.WithOpName("tl2"), allow1, shape);
  auto tl2w1 =
      ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.output_handle, input);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 2e-4);
  }
}
878
// Tests that TensorListPushBackBatch and TensorListConcatLists propagate the
// allow painting across merged/stacked list handles: every list handle in the
// chain ends up with element_dtype DT_HALF.
TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  // Stack the two scalar variant handles into a rank-1 tensor so they can be
  // fed to TensorListPushBackBatch together.
  Output tl1_tl2 =
      ops::Stack(s.WithOpName("tl1_tl2"), {tl1.handle, tl2.handle});
  Output allow1_allow1 =
      ops::Stack(s.WithOpName("allow1_allow1"), {allow1, allow1});
  auto tl12w1 = ops::TensorListPushBackBatch(s.WithOpName("tl12w1"), tl1_tl2,
                                             allow1_allow1);
  // Split the batched handles apart again and reshape each back to a scalar
  // (an empty shape vector) so they can be passed to TensorListConcatLists.
  OutputList tl12w1_outputs =
      ops::Split(s.WithOpName("tl12w1_outputs"), 0, tl12w1.output_handles, 2)
          .output;
  Output scalar_shape = ops::Const(s.WithOpName("scalar_shape"), 0, {0});
  Output tl12w1_output0 = ops::Reshape(s.WithOpName("tl12w1_output0"),
                                       tl12w1_outputs[0], scalar_shape);
  Output tl12w1_output1 = ops::Reshape(s.WithOpName("tl12w1_output1"),
                                       tl12w1_outputs[1], scalar_shape);
  Output tl3 = ops::TensorListConcatLists(s.WithOpName("tl3"), tl12w1_output0,
                                          tl12w1_output1, DT_FLOAT);
  Output tl3r1 =
      ops::TensorListPopBack(s.WithOpName("tl3r1"), tl3, shape, DT_FLOAT)
          .tensor;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl3r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl3")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl3r1")->attr().at(type_key).type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
938
// This test passes a tensor list handle through a function with its own
// Tensor List ops inside to test that the types are not changed to a
// conflicting state.
// A separate Tensor List cluster is added to test that it is still changed to
// DT_HALF.
TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
  // Func1(ihandle, x): pushes x onto the list, pops it back off, and returns
  // the resulting handle and popped value.
  FunctionDefLibrary function_lib;
  const Tensor kShape = test::AsTensor<int32>({32, 32});
  FunctionDef func1 = FunctionDefHelper::Define(
      "Func1", {"ihandle: variant", "x: float"},
      {"ohandle: variant", "y: float"}, {},
      {
          {{"tl1w1_handle"},
           "TensorListPushBack",
           {"ihandle", "x"},
           {{"element_dtype", DT_FLOAT}}},
          {{"shape"}, "Const", {}, {{"value", kShape}, {"dtype", DT_INT32}}},
          {{"tl1r1_handle", "tl1r1_data"},
           "TensorListPopBack",
           {"tl1w1_handle", "shape"},
           {{"element_dtype", DT_FLOAT}}},
          {{"ohandle"}, "Identity", {"tl1r1_handle"}, {{"T", DT_VARIANT}}},
          {{"y"}, "Identity", {"tl1r1_data"}, {{"T", DT_FLOAT}}},
      });
  function_lib.add_function()->Swap(&func1);

  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib));
  tensorflow::Input shape = {32, 32};
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
  auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  auto tl1w1 =
      ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, infer1);
  // Build the call to Func1 by hand with NodeBuilder, since custom functions
  // have no generated ops:: wrapper.
  auto _infer1 = tensorflow::ops::AsNodeOut(s, infer1);
  auto _tl1w1_handle = tensorflow::ops::AsNodeOut(s, tl1w1.output_handle);
  auto builder =
      tensorflow::NodeBuilder("Func1", "Func1", s.graph()->op_registry());
  tensorflow::Node* func1_op;
  TF_CHECK_OK(builder.Input(_tl1w1_handle)
                  .Input(_infer1)
                  .Finalize(s.graph(), &func1_op));
  Output func1_handle(func1_op, 0);
  Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"), func1_handle,
                                        shape, DT_FLOAT)
                     .tensor;
  // Independent list cluster that does not pass through the function; it
  // should still be painted DT_HALF.
  auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, infer1);
  Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
                                        tl2w1.output_handle, shape, DT_FLOAT)
                     .tensor;
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), tl1r1, tl2r1);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
1022
GetCudaVersion(const Cluster & cluster)1023 int GetCudaVersion(const Cluster& cluster) {
1024 auto devices = cluster.GetDevices();
1025 for (const auto& device : devices) {
1026 const DeviceProperties& device_properties = device.second;
1027 if (device_properties.type() == "GPU") {
1028 const auto& device_env = device_properties.environment();
1029 auto it = device_env.find("cuda");
1030 if (it != device_env.end()) {
1031 string cuda_version_str = it->second;
1032 return std::stoi(cuda_version_str);
1033 }
1034 }
1035 }
1036 return 0;
1037 }
1038
// BatchMatMul is only treated as an allow op when the CUDA version is at
// least 9.1 (encoded as 9010); on older CUDA the graph must be left entirely
// unchanged.
TEST_F(AutoMixedPrecisionTest, BatchMatMul) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 33, {64, 32, 32});
  Output allow1 = ops::BatchMatMul(s.WithOpName("allow1"), input, input);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  if (GetCudaVersion(*virtual_cluster_.get()) >= 9010) {
    // CUDA >= 9.1: BatchMatMul is converted and two cast nodes are inserted.
    EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  } else {
    // Older CUDA: no conversion at all.
    EXPECT_EQ(output.node_size(), item.graph.node_size());
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
  }

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 3.0e-3);
  }
}
1073
TEST_F(AutoMixedPrecisionTest, EluOp) {
  // Exercises the shared unary-infer-op harness with ops::Elu.
  const auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Elu(sc, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, 1.0e-3, build_op);
}
1081
TEST_F(AutoMixedPrecisionTest, ErfOp) {
  // Exercises the shared unary-infer-op harness with ops::Erf.
  const auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Erf(sc, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, build_op);
}
1089
TEST_F(AutoMixedPrecisionTest, ErfcOp) {
  // Exercises the shared unary-infer-op harness with ops::Erfc.
  const auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Erfc(sc, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, build_op);
}
1097
TEST_F(AutoMixedPrecisionTest, InvOp) {
  // Exercises the shared unary-infer-op harness with ops::Inv; inputs are
  // kept strictly positive to avoid the pole at zero.
  const auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Inv(sc, x);
  };
  TestSimpleUnaryInferOp(0.01, 10, -1, 1.0e-3, build_op);
}
1105
TEST_F(AutoMixedPrecisionTest, LogOp) {
  // Exercises the shared unary-infer-op harness with ops::Log; inputs are
  // kept strictly positive since Log is undefined at and below zero.
  const auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Log(sc, x);
  };
  TestSimpleUnaryInferOp(0.01, 10, 1.0e-3, 2.0e-3, build_op);
}
1113
TEST_F(AutoMixedPrecisionTest, Log1pOp) {
  // Exercises the shared unary-infer-op harness with ops::Log1p; inputs stay
  // above -1, where log1p is defined.
  const auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Log1p(sc, x);
  };
  TestSimpleUnaryInferOp(-0.99, 9, 1.0e-3, 5.0e-3, build_op);
}
1121
TEST_F(AutoMixedPrecisionTest, LogSoftmaxOp) {
  // Exercises the shared unary-infer-op harness with ops::LogSoftmax.
  const auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::LogSoftmax(sc, x);
  };
  TestSimpleUnaryInferOp(-8, 8, -1, 1.0e-2, build_op);
}
1129
TEST_F(AutoMixedPrecisionTest, ReciprocalOp) {
  // Exercises the shared unary-infer-op harness with ops::Reciprocal; inputs
  // are kept strictly positive to avoid the pole at zero.
  const auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Reciprocal(sc, x);
  };
  TestSimpleUnaryInferOp(0.01, 10, -1, 1.0e-3, build_op);
}
1137
TEST_F(AutoMixedPrecisionTest, SigmoidOp) {
  // Exercises the shared unary-infer-op harness with ops::Sigmoid.
  const auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Sigmoid(sc, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, build_op);
}
1145
TEST_F(AutoMixedPrecisionTest, SoftmaxOp) {
  // Exercises the shared unary-infer-op harness with ops::Softmax.
  const auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Softmax(sc, x);
  };
  TestSimpleUnaryInferOp(-8, 8, 2.0e-3, -1, build_op);
}
1153
TEST_F(AutoMixedPrecisionTest, SoftplusOp) {
  // Exercises the shared unary-infer-op harness with ops::Softplus.
  const auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Softplus(sc, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, 1.0e-3, build_op);
}
1161
TEST_F(AutoMixedPrecisionTest, SqrtOp) {
  // Exercises the shared unary-infer-op harness with ops::Sqrt; inputs are
  // non-negative, where Sqrt is defined.
  const auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Sqrt(sc, x);
  };
  TestSimpleUnaryInferOp(0, 10, 1.0e-3, 1.0e-3, build_op);
}
1169
TEST_F(AutoMixedPrecisionTest, TanhOp) {
  // Exercises the shared unary-infer-op harness with ops::Tanh.
  const auto build_op = [](const tensorflow::Scope& sc, Output x) -> Output {
    return ops::Tanh(sc, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, build_op);
}
1177
1178 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1179
1180 #if INTEL_MKL
1181 #ifdef ENABLE_INTEL_MKL_BFLOAT16
1182
1183 class AutoMixedPrecisionMklTest : public GrapplerTest {
1184 protected:
SetUp()1185 void SetUp() override {
1186 virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0));
1187 TF_CHECK_OK(virtual_cluster_->Provision());
1188 }
TearDown()1189 void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }
1190
1191 std::unique_ptr<Cluster> virtual_cluster_;
1192 };
1193
// Checks that a graph that already computes in bfloat16 through explicit Cast
// nodes is left structurally equivalent by the MKL optimizer and keeps all of
// its existing types.
TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_BFLOAT16);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
  Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  VLOG(1) << output.DebugString();

  // The optimizer must not add, remove, or rewire anything.
  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
1231
// MKL (bfloat16) version of the basic painting test: the allow op (MatMul)
// and the clear ops adjacent to it become DT_BFLOAT16, while deny ops (Exp,
// Log, SparseMatMul) and the clear/infer ops attached to them stay DT_FLOAT.
TEST_F(AutoMixedPrecisionMklTest, Simple) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
  Output deny2 = ops::Log(s.WithOpName("deny2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), deny2);
  Output deny3 = ops::SparseMatMul(s.WithOpName("deny3"), clr4, clr4);
  Output clr5 = ops::Relu(s.WithOpName("clr5"), deny3);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
  // SparseMatMul carries separate type attrs for each operand.
  EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Ta").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Tb").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
1280
// MKL (bfloat16) version of the TensorListSetGet test: list tl1, written from
// and read into MatMul (allow) ops, switches its element_dtype to
// DT_BFLOAT16, while tl2, which only stores the Const input, stays DT_FLOAT.
TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output idx1 = ops::Const(s.WithOpName("idx1"), 1);
  Output idx2 = ops::Const(s.WithOpName("idx2"), 2);
  Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
  auto tl1w1 =
      ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1w2 =
      ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
  // Ensure that TensorListResize doesn't cause any problems.
  Output tl1rs =
      ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
  Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
                                        shape, DT_FLOAT)
                     .item;
  Output infer1 = ops::Mul(s.WithOpName("infer1"), tl1r1, tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  auto tl1w3 =
      ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
  Output tl1r2 =
      ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
                             shape, DT_FLOAT)
          .item;
  // Control list: never touches an allow op, so it must stay float32.
  auto tl2 = ops::TensorListReserve(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListSetItem(s.WithOpName("tl2w1"), tl2.handle, idx1, input);
  Output tl2r1 =
      ops::TensorListGetItem(s.WithOpName("tl2r1"), tl2w1.output_handle, idx1,
                             shape, DT_FLOAT)
          .item;
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);

  // bfloat16 has only 8 mantissa bits, hence the loose 1e-2 tolerance.
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-2);
  }
}
1356
1357 #endif // ENABLE_INTEL_MKL_BFLOAT16
1358 #endif // INTEL_MKL
1359
1360 } // namespace
1361 } // namespace grappler
1362 } // namespace tensorflow
1363
1364 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || (INTEL_MKL &&
1365 // defined(ENABLE_INTEL_MKL_BFLOAT16))
1366