1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || INTEL_MKL
17
18 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
19
20 #include <utility>
21 #include <vector>
22
23 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
24 #include "tensorflow/cc/ops/list_ops.h"
25 #include "tensorflow/cc/ops/math_ops.h"
26 #include "tensorflow/cc/ops/standard_ops.h"
27 #include "tensorflow/core/framework/function_testlib.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/tensor_testutil.h"
30 #include "tensorflow/core/grappler/clusters/single_machine.h"
31 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
32 #include "tensorflow/core/grappler/devices.h"
33 #include "tensorflow/core/grappler/graph_view.h"
34 #include "tensorflow/core/grappler/utils/grappler_test.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/lib/random/random.h"
37
38 // TODO(benbarsdell): Improve the numerical checks in these tests. The tests
39 // were originally written only to check the graph coloring, so the graphs do
40 // not have particularly realistic numerical behavior.
41
42 namespace tensorflow {
43 namespace grappler {
44 namespace {
45
46 template <DataType DTYPE>
GenerateIdentityMatrix(int64_t height,int64_t width)47 Tensor GenerateIdentityMatrix(int64_t height, int64_t width) {
48 typedef typename EnumToDataType<DTYPE>::Type T;
49 Tensor tensor(DTYPE, TensorShape{height, width});
50 for (int64_t i = 0; i < height; ++i) {
51 for (int64_t j = 0; j < width; ++j) {
52 tensor.matrix<T>()(i, j) = i == j;
53 }
54 }
55 return tensor;
56 }
57
58 template <DataType DTYPE>
GenerateRandomTensorInRange(const TensorShape & shape,double minval,double maxval)59 Tensor GenerateRandomTensorInRange(const TensorShape& shape, double minval,
60 double maxval) {
61 typedef typename EnumToDataType<DTYPE>::Type T;
62 Tensor tensor(DTYPE, shape);
63 for (auto i = 0; i < tensor.NumElements(); i++)
64 tensor.flat<T>()(i) =
65 (random::New64() % 65536 / 65536.0) * (maxval - minval) + minval;
66 return tensor;
67 }
68
VerifyGraphsEquivalent(const GraphDef & original_graph,const GraphDef & optimized_graph,const string & func)69 void VerifyGraphsEquivalent(const GraphDef& original_graph,
70 const GraphDef& optimized_graph,
71 const string& func) {
72 EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
73 GraphView optimized_view(&optimized_graph);
74 for (int i = 0; i < original_graph.node_size(); ++i) {
75 const NodeDef& original = original_graph.node(i);
76 const NodeDef& optimized = *optimized_view.GetNode(original.name());
77 EXPECT_EQ(original.name(), optimized.name()) << func;
78 EXPECT_EQ(original.op(), optimized.op()) << func;
79 EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
80 if (original.input_size() == optimized.input_size()) {
81 for (int j = 0; j < original.input_size(); ++j) {
82 EXPECT_EQ(original.input(j), optimized.input(j)) << func;
83 }
84 }
85 }
86 }
87
// Currently, this test suite only passes when TensorFlow is built with
// CUDA/HIP support, because otherwise the optimizer will not turn clearlist
// nodes to float16. When looking at clearlist nodes, this optimizer checks if
// the nodes have a float16 GPU OpKernel, but without CUDA/HIP there are no
// GPU OpKernels at all.
92 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
93
// Minimum GPU architecture (major, minor) required for these tests; GPUs
// below this are treated as unavailable (checked in SetUp). constexpr makes
// the constant compile-time and avoids dynamic initialization.
constexpr std::pair<int, int> kMinGPUArch = {7, 0};
95
// Test fixture for the AutoMixedPrecision grappler pass. When a suitable
// physical GPU is present it provisions a real SingleMachine cluster;
// otherwise it provisions a VirtualCluster advertising one virtual GPU so
// the optimizer still sees a float16-capable device.
class AutoMixedPrecisionTest : public GrapplerTest {
 protected:
  void SetUp() override {
    int num_gpus = GetNumAvailableGPUs();
    // If GPUs are available, require that they all satisfy the min arch.
    gpu_available_ = (num_gpus > 0);
#if GOOGLE_CUDA
    gpu_available_ =
        gpu_available_ && (num_gpus == GetNumAvailableGPUs(kMinGPUArch));
#else  // On the ROCm path we always fall through to the virtual GFX906 GPU.
    gpu_available_ = false;
#endif
    if (gpu_available_) {
      virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 1));
    } else {
      // No usable physical GPU: describe a virtual one so the optimizer's
      // device checks still pass.
      DeviceProperties device_properties;
      device_properties.set_type("GPU");
#if GOOGLE_CUDA
      device_properties.mutable_environment()->insert({"architecture", "7"});
      device_properties.mutable_environment()->insert({"cuda", "9010"});
#else
      device_properties.mutable_environment()->insert(
          {"architecture", "gfx906"});
#endif
      virtual_cluster_.reset(
          new VirtualCluster({{"/GPU:1", device_properties}}));
    }
    TF_CHECK_OK(virtual_cluster_->Provision());
  }

  void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }

  // Adds a node named `name` running `op` with the given `inputs` to `graph`,
  // filling in each op's required type attributes with DT_FLOAT (plus "N" /
  // "out_type" where the op demands them). Returns the new NodeDef.
  NodeDef* AddSimpleNode(const string& name, const string& op,
                         const std::vector<string>& inputs,
                         GraphDef* graph) const {
    std::vector<std::pair<string, AttrValue>> attributes;
    if (op == "AddN" || op == "ShapeN") {
      AttrValue num_inputs;
      num_inputs.set_i(inputs.size());
      attributes.emplace_back("N", num_inputs);
    }
    if (op == "ShapeN") {
      AttrValue out_type;
      out_type.set_type(DT_INT32);
      attributes.emplace_back("out_type", out_type);
    }
    AttrValue type;
    type.set_type(DT_FLOAT);
    // Each op family names its dtype attribute differently.
    if (op == "Const" || op == "Placeholder" || op == "VariableV2" ||
        op == "VarHandleOp" || op == "ReadVariableOp") {
      attributes.emplace_back("dtype", type);
    } else if (op == "SparseMatMul") {
      attributes.emplace_back("Ta", type);
      attributes.emplace_back("Tb", type);
    } else if (op == "IdentityN") {
      // IdentityN takes a type list, one entry per input.
      AttrValue type_list;
      for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
        type_list.mutable_list()->add_type(DT_FLOAT);
      }
      attributes.emplace_back("T", type_list);
    } else if (op == "StackV2" || op == "StackPopV2") {
      attributes.emplace_back("elem_type", type);
    } else if (op == "Cast") {
      attributes.emplace_back("SrcT", type);
      attributes.emplace_back("DstT", type);
    } else {
      attributes.emplace_back("T", type);
    }
    return AddNode(name, op, inputs, attributes, graph);
  }

  // Builds allow1(MatMul) -> infer1(op under test) -> allow2(MatMul) using an
  // identity matrix so the MatMuls are numerically neutral, runs the
  // optimizer, expects all three nodes to be painted DT_HALF, and compares
  // results against the unoptimized graph within atol/rtol.
  void TestSimpleUnaryInferOp(
      double input_min, double input_max, double atol, double rtol,
      const std::function<Output(const tensorflow::Scope&, Output)>&
          test_op_factory) {
    int size = 128;
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output eye = ops::Const(s.WithOpName("eye"),
                            GenerateIdentityMatrix<DT_FLOAT>(size, size));
    Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
    Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, eye);
    Output infer1 = test_op_factory(s.WithOpName("infer1"), allow1);
    Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, eye);
    Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
    GrapplerItem item;
    item.fetch = {"fetch1"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    auto input_tensor = GenerateRandomTensorInRange<DT_FLOAT>(
        TensorShape({size, size}), input_min, input_max);
    std::vector<std::pair<string, Tensor>> feed = {{"input", input_tensor}};
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);

    AutoMixedPrecision optimizer;
    GraphDef output;
    TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

    VLOG(1) << output.DebugString();

    // The placeholder stays float; the allow-list MatMuls and the op between
    // them should have been converted to half precision.
    GraphView output_view(&output);
    EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(),
              DT_FLOAT);
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
    EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
    EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

    auto tensors = EvaluateNodes(output, item.fetch, feed);
    EXPECT_EQ(tensors.size(), tensors_expected.size());
    EXPECT_EQ(tensors.size(), item.fetch.size());
    for (int i = 0; i < item.fetch.size(); ++i) {
      test::ExpectClose(tensors_expected[i], tensors[i], atol, rtol);
    }
  }

  // Cluster the optimizer runs against (real or virtual GPU).
  std::unique_ptr<Cluster> virtual_cluster_;
  // True iff all available physical GPUs satisfy kMinGPUArch.
  bool gpu_available_;
};
212
// A graph containing only deny/clear/infer-list ops (no allow-list op) must
// not be converted at all: the optimized graph is structurally identical and
// every node keeps DT_FLOAT.
TEST_F(AutoMixedPrecisionTest, NoOp) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.234f, {32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);

  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t index matches item.fetch.size() and avoids a signed/unsigned
  // comparison.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
249
// A graph that already casts to half around its MatMul must be left
// structurally unchanged, with all the pre-existing cast/compute dtypes
// preserved.
TEST_F(AutoMixedPrecisionTest, AlreadyFp16) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_HALF);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
  Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  VLOG(1) << output.DebugString();

  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t index avoids a signed/unsigned comparison with size().
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
287
// A chain mixing deny/clear/infer/allow ops: only the region around the
// allow-list MatMul (clr2 -> allow1 -> clr3) becomes half; ops adjacent to
// deny-list nodes stay float. Two Cast nodes are inserted (+2 nodes).
TEST_F(AutoMixedPrecisionTest, Simple) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
  Output infer2 = ops::Log(s.WithOpName("infer2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), infer2);
  Output deny2 = ops::SparseMatMul(s.WithOpName("deny2"), clr4, clr4);
  Output clr5 = ops::Relu(s.WithOpName("clr5"), deny2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Ta").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Tb").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t index avoids a signed/unsigned comparison with size().
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
336
// Clear-list ops linked through a multi-input ShapeN: the half paint applied
// around the MatMul must propagate through the clear chain in both
// directions (clr2/clr4 become half even though only clr1 feeds allow1).
TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
  auto clr3 = ops::ShapeN(s.WithOpName("clr3"), {clr1, clr2});
  Output clr4 = ops::Relu(s.WithOpName("clr4"), clr2);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), clr4);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 3);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t index avoids a signed/unsigned comparison with size().
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
375
// Nodes listed in item.fetch must keep DT_FLOAT outputs even when they would
// otherwise be painted half (allow1, clr3); only allow2, which is not
// fetched, is converted.
TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output deny1 = ops::Exp(s.WithOpName("deny1"), infer1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), deny1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow2);
  Output deny2 = ops::Exp(s.WithOpName("deny2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), deny2);

  GrapplerItem item;
  item.fetch = {"allow1", "clr2", "clr3"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t index avoids a signed/unsigned comparison with size().
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-3);
  }
}
420
// A MatMul explicitly placed on the CPU must stay DT_FLOAT, and the infer op
// feeding it stays float as well; only the unpinned GPU-side MatMul region
// is converted to half.
TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
  Output allow2 =
      ops::MatMul(s.WithOpName("allow2").WithDevice(
                      "/job:localhost/replica:0/task:0/device:CPU:0"),
                  infer1, infer1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), allow2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t index avoids a signed/unsigned comparison with size().
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
461
// An Identity fed by a Variable must keep DT_FLOAT (clr1), while an Identity
// fed by a Const (clr2) may be converted to half along with its downstream
// MatMul.
TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output var1 = ops::Variable(s.WithOpName("var1"), {32, 32}, DT_FLOAT);
  Output clr1 = ops::Identity(s.WithOpName("clr1"), var1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, clr1);
  Output input2 = ops::Const(s.WithOpName("input2"), 1.f / 32, {32, 32});
  Output clr2 = ops::Identity(s.WithOpName("clr2"), input2);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), input, clr2);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto var1_tensor =
      GenerateConstantTensor<DT_FLOAT>(TensorShape({32, 32}), 3.141593f);
  std::vector<std::pair<string, Tensor>> feed = {{"var1", var1_tensor}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 5);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("var1")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("input2")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t index avoids a signed/unsigned comparison with size().
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-3);
  }
}
505
// Converting FusedBatchNorm/FusedBatchNormGrad to half requires upgrading
// them to the V2 ops, which carry a separate DT_FLOAT attr "U" for the
// scale/offset/statistics inputs while "T" becomes DT_HALF.
TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // Uses NHWC data format because non-GPU execution does not support NCHW.
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {8, 56, 56, 16});
  Output weight = ops::Const(s.WithOpName("weight"), 2.f, {3, 3, 16, 16});
  Output scale = ops::Const(s.WithOpName("scale"), 3.f, {16});
  Output offset = ops::Const(s.WithOpName("offset"), 4.f, {16});
  Output mean = ops::Const(s.WithOpName("mean"), 5.f, {0});
  Output variance = ops::Const(s.WithOpName("variance"), 6.f, {0});
  Output allow1 =
      ops::Conv2D(s.WithOpName("allow1"), input, weight, {1, 1, 1, 1}, "SAME",
                  ops::Conv2D::DataFormat("NHWC"));
  auto fbn1_op =
      ops::FusedBatchNorm(s.WithOpName("fbn1"), allow1, scale, offset, mean,
                          variance, ops::FusedBatchNorm::DataFormat("NHWC"));
  Output fbn1 = fbn1_op.y;
  Output fbn1_rs1 = fbn1_op.reserve_space_1;
  Output fbn1_rs2 = fbn1_op.reserve_space_2;
  Output bng1 = ops::FusedBatchNormGrad(
                    s.WithOpName("bng1"), fbn1, allow1, scale, fbn1_rs1,
                    fbn1_rs2, ops::FusedBatchNormGrad::DataFormat("NHWC"))
                    .x_backprop;
  Output infer1 = ops::Add(s.WithOpName("infer1"), fbn1, bng1);
  Output allow2 =
      ops::Conv2D(s.WithOpName("allow2"), infer1, weight, {1, 1, 1, 1}, "SAME",
                  ops::Conv2D::DataFormat("NHWC"));
  Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 3);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("fbn1")->op(), "FusedBatchNormV2");
  EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("U").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("bng1")->op(), "FusedBatchNormGradV2");
  EXPECT_EQ(output_view.GetNode("bng1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("bng1")->attr().at("U").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t index avoids a signed/unsigned comparison with size().
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-2);
  }
}
564
// Ops with list-type attributes (IdentityN) and repeated-input attributes
// (AddN) sitting between two allow-list MatMuls: every entry of IdentityN's
// type list must be converted to half, not just a scalar "T".
TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto clr1_op = ops::IdentityN(s.WithOpName("clr1"), {allow1, allow1, allow1});
  Output infer1 =
      ops::AddN(s.WithOpName("infer1"),
                {clr1_op.output[0], clr1_op.output[1], clr1_op.output[2]});
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  // All entries of IdentityN's type list should have been repainted.
  for (auto type : output_view.GetNode("clr1")->attr().at("T").list().type()) {
    EXPECT_EQ(type, DT_HALF);
  }
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t index avoids a signed/unsigned comparison with size().
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
604
// A pre-existing bool->float Cast feeding a MatMul is rewritten in place to
// bool->half rather than inserting an extra Cast pair (+1 node overall).
TEST_F(AutoMixedPrecisionTest, ExistingCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), true, {32, 32});
  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_FLOAT);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
  Output fetch = ops::Identity(s.WithOpName("fetch"), allow1);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 1);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("SrcT").type(), DT_BOOL);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t index avoids a signed/unsigned comparison with size().
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
636
// Builds a while-loop whose Merge node sits between a deny-list op and the
// loop body; the deny paint must win and propagate around the recurrent edge
// so that the loop plumbing (mrg1, nxt1, mrg2) stays DT_FLOAT while the
// MatMul inside the body still becomes half.
TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output ent1 =
      ops::internal::Enter(s.WithOpName("ent1"), deny1, "loop1").output;
  // Note that the second input is later replaced with "nxt1".
  Output mrg1 = ops::Merge(s.WithOpName("mrg1"), {ent1, ent1}).output;
  // For simplicity, the loop condition is constant false.
  Output con1 = ops::Const(s.WithOpName("con1"), false, {});
  Output lpc1 = ops::LoopCond(s.WithOpName("lpc1"), con1).output;
  auto swt1 = ops::Switch(s.WithOpName("swt1"), mrg1, lpc1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), swt1.output_true);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), infer1, infer1);
  Output nxt1 = ops::NextIteration(s.WithOpName("nxt1"), allow1);
  Output ext1 = ops::internal::Exit(s.WithOpName("ext1"), swt1.output_false);
  Output fetch = ops::Identity(s.WithOpName("fetch"), ext1);
  // Add a second merge node from the same NextIteration node. This case arises
  // during graph optimization of some models.
  auto mrg2 = ops::Merge(s.WithOpName("mrg2"), {ent1, nxt1});

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  NodeMap node_map_original(&item.graph);
  auto merge_node = node_map_original.GetNode("mrg1");
  // Modify the graph to create a loop.
  merge_node->set_input(1, "nxt1");
  // Add a control edge to ensure the loop condition is inside the frame.
  auto const_node = node_map_original.GetNode("con1");
  const_node->add_input("^mrg1");
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  // Note that mrg1 gets painted deny because it is between deny1 and infer1.
  // This forces nxt1 and mrg2 to be painted deny as well (they would otherwise
  // be painted allow because they are clear and have a direct path to allow1).
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("ent1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("mrg1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("swt1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("nxt1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("ext1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("mrg2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t index avoids a signed/unsigned comparison with size().
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
698
TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
  // Checks that a TensorList accessed via TensorListSetItem/GetItem is painted
  // as a single unit: tl1's reads and writes connect to MatMul ("allow") ops,
  // so its element_dtype (and that of every op touching its handle) should be
  // converted to DT_HALF, while tl2 never connects to an allow op and must
  // keep DT_FLOAT.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output idx1 = ops::Const(s.WithOpName("idx1"), 1);
  Output idx2 = ops::Const(s.WithOpName("idx2"), 2);
  Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
  auto tl1w1 =
      ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1w2 =
      ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
  // Ensure that TensorListResize doesn't cause any problems.
  Output tl1rs =
      ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
  Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
                                        shape, DT_FLOAT)
                     .item;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  auto tl1w3 =
      ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
  Output tl1r2 =
      ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
                             shape, DT_FLOAT)
          .item;
  auto tl2 = ops::TensorListReserve(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListSetItem(s.WithOpName("tl2w1"), tl2.handle, idx1, input);
  Output tl2r1 =
      ops::TensorListGetItem(s.WithOpName("tl2r1"), tl2w1.output_handle, idx1,
                             shape, DT_FLOAT)
          .item;
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Baseline evaluation in full float precision, before optimization.
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Exactly two nodes are expected to be added (presumably float<->half
  // conversion nodes at the paint boundary -- TODO confirm).
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // Loose tolerance because part of the optimized graph computes in half
  // precision.
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
769
TEST_F(AutoMixedPrecisionTest, TensorListPushPop) {
  // Same paint-the-whole-list check as TensorListSetGet, but using the
  // push/pop flavor of the TensorList ops (EmptyTensorList /
  // TensorListPushBack / TensorListPopBack): tl1 connects to MatMul ("allow")
  // ops and should become DT_HALF end to end, while tl2 stays DT_FLOAT.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  auto tl1w1 =
      ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1w2 = ops::TensorListPushBack(s.WithOpName("tl1w2"),
                                       tl1w1.output_handle, allow1);
  Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"),
                                        tl1w2.output_handle, shape, DT_FLOAT)
                     .tensor;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  auto tl1w3 =
      ops::TensorListPushBack(s.WithOpName("tl1w3"), tl1.handle, allow2);
  Output tl1r2 = ops::TensorListPopBack(s.WithOpName("tl1r2"),
                                        tl1w3.output_handle, shape, DT_FLOAT)
                     .tensor;
  auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, input);
  Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
                                        tl2w1.output_handle, shape, DT_FLOAT)
                     .tensor;
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Baseline evaluation in full float precision, before optimization.
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Exactly two nodes are expected to be added (presumably float<->half
  // conversion nodes at the paint boundary -- TODO confirm).
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // Loose tolerance because part of the optimized graph computes in half
  // precision.
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
832
TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) {
  // Checks TensorListFromTensor/TensorListStack: the list is created directly
  // from an allow node's output (allow1), so the whole tl1 chain should be
  // converted to DT_HALF.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32};
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1 = ops::TensorListFromTensor(s.WithOpName("tl1"), allow1, shape);
  Output tl1r1 = ops::TensorListStack(s.WithOpName("tl1r1"), tl1.output_handle,
                                      shape, DT_FLOAT)
                     .tensor;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);

  // This tests that an allow-painted object node (tl2) will force an unpainted
  // client node (tl2w1) to be painted allow as well. (Without the force, tl2w1
  // would remain unpainted, producing an invalid graph).
  auto tl2 = ops::TensorListFromTensor(s.WithOpName("tl2"), allow1, shape);
  auto tl2w1 =
      ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.output_handle, input);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Baseline evaluation in full float precision, before optimization.
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Exactly two nodes are expected to be added (presumably float<->half
  // conversion nodes at the paint boundary -- TODO confirm).
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // Loose tolerance because part of the optimized graph computes in half
  // precision.
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 2e-4);
  }
}
882
TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
  // Checks that list handles flowing through Stack/Split/Reshape and the
  // batched/concatenating list ops (TensorListPushBackBatch,
  // TensorListConcatLists) are tracked correctly: tl1, tl2, and the
  // concatenated tl3 all carry data produced by / consumed by MatMul
  // ("allow") ops, so all three lists should become DT_HALF.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  // Batch the two list handles (and two copies of allow1) so a single
  // TensorListPushBackBatch writes into both lists at once.
  Output tl1_tl2 =
      ops::Stack(s.WithOpName("tl1_tl2"), {tl1.handle, tl2.handle});
  Output allow1_allow1 =
      ops::Stack(s.WithOpName("allow1_allow1"), {allow1, allow1});
  auto tl12w1 = ops::TensorListPushBackBatch(s.WithOpName("tl12w1"), tl1_tl2,
                                             allow1_allow1);
  // Split the batched handles back apart and reshape each to a scalar
  // (rank-0) handle so they can feed TensorListConcatLists.
  OutputList tl12w1_outputs =
      ops::Split(s.WithOpName("tl12w1_outputs"), 0, tl12w1.output_handles, 2)
          .output;
  Output scalar_shape = ops::Const(s.WithOpName("scalar_shape"), 0, {0});
  Output tl12w1_output0 = ops::Reshape(s.WithOpName("tl12w1_output0"),
                                       tl12w1_outputs[0], scalar_shape);
  Output tl12w1_output1 = ops::Reshape(s.WithOpName("tl12w1_output1"),
                                       tl12w1_outputs[1], scalar_shape);
  Output tl3 = ops::TensorListConcatLists(s.WithOpName("tl3"), tl12w1_output0,
                                          tl12w1_output1, DT_FLOAT);
  Output tl3r1 =
      ops::TensorListPopBack(s.WithOpName("tl3r1"), tl3, shape, DT_FLOAT)
          .tensor;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl3r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Baseline evaluation in full float precision, before optimization.
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Exactly two nodes are expected to be added (presumably float<->half
  // conversion nodes at the paint boundary -- TODO confirm).
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl3")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl3r1")->attr().at(type_key).type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // Loose tolerance because part of the optimized graph computes in half
  // precision.
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
942
TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
  // This test passes a tensor list handle through a function with its own
  // Tensor List ops inside to test that the types are not changed to a
  // conflicting state.
  // A separate Tensor List cluster is added to test that it is still changed
  // to DT_HALF.
  FunctionDefLibrary function_lib;
  const Tensor kShape = test::AsTensor<int32>({32, 32});
  // Func1(ihandle, x): pushes x onto the list, pops it back off, and returns
  // the resulting handle and popped value through Identity nodes.
  FunctionDef func1 = FunctionDefHelper::Define(
      "Func1", {"ihandle: variant", "x: float"},
      {"ohandle: variant", "y: float"}, {},
      {
          {{"tl1w1_handle"},
           "TensorListPushBack",
           {"ihandle", "x"},
           {{"element_dtype", DT_FLOAT}}},
          {{"shape"}, "Const", {}, {{"value", kShape}, {"dtype", DT_INT32}}},
          {{"tl1r1_handle", "tl1r1_data"},
           "TensorListPopBack",
           {"tl1w1_handle", "shape"},
           {{"element_dtype", DT_FLOAT}}},
          {{"ohandle"}, "Identity", {"tl1r1_handle"}, {{"T", DT_VARIANT}}},
          {{"y"}, "Identity", {"tl1r1_data"}, {{"T", DT_FLOAT}}},
      });
  function_lib.add_function()->Swap(&func1);

  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib));
  tensorflow::Input shape = {32, 32};
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
  auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  auto tl1w1 =
      ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, infer1);
  // The C++ ops API has no wrapper for calling a custom function, so build
  // the "Func1" call node by hand with NodeBuilder.
  auto _infer1 = tensorflow::ops::AsNodeOut(s, infer1);
  auto _tl1w1_handle = tensorflow::ops::AsNodeOut(s, tl1w1.output_handle);
  auto builder =
      tensorflow::NodeBuilder("Func1", "Func1", s.graph()->op_registry());
  tensorflow::Node* func1_op;
  TF_CHECK_OK(builder.Input(_tl1w1_handle)
                  .Input(_infer1)
                  .Finalize(s.graph(), &func1_op));
  Output func1_handle(func1_op, 0);
  Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"), func1_handle,
                                        shape, DT_FLOAT)
                     .tensor;
  // Second, self-contained list cluster that does not cross the function
  // boundary; this one is expected to be converted to DT_HALF.
  auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, infer1);
  Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
                                        tl2w1.output_handle, shape, DT_FLOAT)
                     .tensor;
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), tl1r1, tl2r1);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Baseline evaluation in full float precision, before optimization.
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  const char* type_key = "element_dtype";
  // Note: no expectations are placed on the tl1 cluster that crosses the
  // function boundary; the test only requires that the graph stays valid and
  // numerically equivalent, and that the independent tl2 cluster is still
  // converted.
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // Loose tolerance because part of the optimized graph computes in half
  // precision.
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
1026
GetCudaVersion(const Cluster & cluster)1027 int GetCudaVersion(const Cluster& cluster) {
1028 auto devices = cluster.GetDevices();
1029 for (const auto& device : devices) {
1030 const DeviceProperties& device_properties = device.second;
1031 if (device_properties.type() == "GPU") {
1032 const auto& device_env = device_properties.environment();
1033 auto it = device_env.find("cuda");
1034 if (it != device_env.end()) {
1035 string cuda_version_str = it->second;
1036 return std::stoi(cuda_version_str);
1037 }
1038 }
1039 }
1040 return 0;
1041 }
1042
// Returns true if the cluster's GPU is usable for auto mixed precision:
// on CUDA builds this requires a reported CUDA version of at least 9010;
// on non-CUDA builds (e.g. ROCm) all GPUs are accepted.
bool IsSupportedGPU(const Cluster& cluster) {
// Use `#if` rather than `#ifdef` for consistency with the file-level
// `#if GOOGLE_CUDA || ...` guard: GOOGLE_CUDA is tested by value, so a
// definition of 0 must select the non-CUDA branch (an undefined macro
// also evaluates to 0 here).
#if GOOGLE_CUDA
  return GetCudaVersion(cluster) >= 9010;
#else
  return true;
#endif
}
1050
TEST_F(AutoMixedPrecisionTest, BatchMatMul) {
  // BatchMatMul is an "allow" op only on supported GPUs (CUDA version >=
  // 9010 per IsSupportedGPU); on unsupported devices the graph must be left
  // completely unchanged.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 33, {64, 32, 32});
  Output allow1 = ops::BatchMatMul(s.WithOpName("allow1"), input, input);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Baseline evaluation in full float precision, before optimization.
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  if (IsSupportedGPU(*virtual_cluster_.get())) {
    // Converted: two extra (conversion) nodes and a half-precision matmul.
    EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  } else {
    // Unsupported GPU: optimizer must not touch the graph.
    EXPECT_EQ(output.node_size(), item.graph.node_size());
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
  }

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // Loose tolerance because the optimized graph may compute in half
  // precision.
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 3.0e-3);
  }
}
1085
// Each test below checks one unary op that the optimizer classifies as
// "infer" (painted half only when adjacent to allow ops), via
// TestSimpleUnaryInferOp. Based on the call sites, the first two arguments
// give the input value range and the last two numeric arguments are error
// tolerances for comparing the optimized graph against the float baseline,
// with -1 apparently disabling that particular tolerance check -- TODO
// confirm against TestSimpleUnaryInferOp's definition (outside this excerpt).
TEST_F(AutoMixedPrecisionTest, EluOp) {
  TestSimpleUnaryInferOp(
      -5, 5, 1.0e-3, 1.0e-3,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Elu(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, ErfOp) {
  TestSimpleUnaryInferOp(
      -5, 5, 1.0e-3, -1,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Erf(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, ErfcOp) {
  TestSimpleUnaryInferOp(
      -5, 5, 1.0e-3, -1,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Erfc(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, InvOp) {
  TestSimpleUnaryInferOp(
      0.01, 10, -1, 1.0e-3,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Inv(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, LogOp) {
  TestSimpleUnaryInferOp(
      0.01, 10, 1.0e-3, 2.0e-3,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Log(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, Log1pOp) {
  TestSimpleUnaryInferOp(
      -0.99, 9, 1.0e-3, 5.0e-3,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Log1p(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, LogSoftmaxOp) {
  TestSimpleUnaryInferOp(
      -8, 8, -1, 1.0e-2,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::LogSoftmax(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, ReciprocalOp) {
  TestSimpleUnaryInferOp(
      0.01, 10, -1, 1.0e-3,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Reciprocal(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, SigmoidOp) {
  TestSimpleUnaryInferOp(
      -5, 5, 1.0e-3, -1,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Sigmoid(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, SoftmaxOp) {
  TestSimpleUnaryInferOp(
      -8, 8, 2.0e-3, -1,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Softmax(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, SoftplusOp) {
  TestSimpleUnaryInferOp(
      -5, 5, 1.0e-3, 1.0e-3,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Softplus(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, SqrtOp) {
  TestSimpleUnaryInferOp(
      0, 10, 1.0e-3, 1.0e-3,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Sqrt(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, TanhOp) {
  TestSimpleUnaryInferOp(
      -5, 5, 1.0e-3, -1,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Tanh(scope, input);
      });
}
1189
1190 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1191
1192 #if INTEL_MKL
1193
// Fixture for the MKL (bfloat16-on-CPU) flavor of auto mixed precision.
// Unlike the GPU tests above, it provisions a real SingleMachine cluster
// with one CPU and no GPUs.
class AutoMixedPrecisionMklTest : public GrapplerTest {
 protected:
  void SetUp() override {
    // 10s timeout, 1 CPU, 0 GPUs.
    virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0));
    TF_CHECK_OK(virtual_cluster_->Provision());
  }
  void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }

  std::unique_ptr<Cluster> virtual_cluster_;
};
1204
TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) {
  // The graph already contains explicit Casts to/from bfloat16 around the
  // MatMul; the optimizer should recognize this and leave the graph
  // structurally equivalent (checked via VerifyGraphsEquivalent below).
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_BFLOAT16);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
  Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  // MKL mode targets bfloat16 on CPU instead of float16 on GPU.
  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  VLOG(1) << output.DebugString();

  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
  GraphView output_view(&output);
  // All pre-existing dtypes must be preserved exactly.
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // Tight tolerance: the numerics should be unchanged since the graph is.
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
1243
TEST_F(AutoMixedPrecisionMklTest, Simple) {
  // Basic painting test for MKL mode: only the region around the MatMul
  // ("allow") op -- clr2, allow1, clr3 -- should be converted to bfloat16;
  // deny ops (Exp, Log, SparseMatMul) and everything attached only to them
  // stay float.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
  Output deny2 = ops::Log(s.WithOpName("deny2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), deny2);
  Output deny3 = ops::SparseMatMul(s.WithOpName("deny3"), clr4, clr4);
  Output clr5 = ops::Relu(s.WithOpName("clr5"), deny3);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Baseline evaluation in full float precision, before optimization.
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Exactly two nodes are expected to be added (presumably float<->bfloat16
  // conversion nodes at the paint boundary -- TODO confirm).
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
  // SparseMatMul has separate dtype attrs for each operand.
  EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Ta").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Tb").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // Loose tolerance because part of the optimized graph computes in bfloat16.
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
1293
TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
  // MKL/bfloat16 counterpart of AutoMixedPrecisionTest.TensorListSetGet:
  // the tl1 list connects to MatMul ("allow") ops and should be converted to
  // DT_BFLOAT16 as a unit, while tl2 stays DT_FLOAT. (Uses Mul rather than
  // Tanh as the "infer" op.)
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output idx1 = ops::Const(s.WithOpName("idx1"), 1);
  Output idx2 = ops::Const(s.WithOpName("idx2"), 2);
  Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
  auto tl1w1 =
      ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1w2 =
      ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
  // Ensure that TensorListResize doesn't cause any problems.
  Output tl1rs =
      ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
  Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
                                        shape, DT_FLOAT)
                     .item;
  Output infer1 = ops::Mul(s.WithOpName("infer1"), tl1r1, tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  auto tl1w3 =
      ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
  Output tl1r2 =
      ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
                             shape, DT_FLOAT)
          .item;
  auto tl2 = ops::TensorListReserve(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListSetItem(s.WithOpName("tl2w1"), tl2.handle, idx1, input);
  Output tl2r1 =
      ops::TensorListGetItem(s.WithOpName("tl2r1"), tl2w1.output_handle, idx1,
                             shape, DT_FLOAT)
          .item;
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Baseline evaluation in full float precision, before optimization.
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Exactly two nodes are expected to be added (presumably float<->bfloat16
  // conversion nodes at the paint boundary -- TODO confirm).
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // Looser tolerance than the fp16 tests: bfloat16 has fewer mantissa bits.
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-2);
  }
}
1370
1371 #endif // INTEL_MKL
1372
1373 } // namespace
1374 } // namespace grappler
1375 } // namespace tensorflow
1376
1377 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || INTEL_MKL
1378