• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || INTEL_MKL
17 
18 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
19 
20 #include <utility>
21 #include <vector>
22 
23 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
24 #include "tensorflow/cc/ops/list_ops.h"
25 #include "tensorflow/cc/ops/math_ops.h"
26 #include "tensorflow/cc/ops/standard_ops.h"
27 #include "tensorflow/core/framework/function_testlib.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/tensor_testutil.h"
30 #include "tensorflow/core/grappler/clusters/single_machine.h"
31 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
32 #include "tensorflow/core/grappler/devices.h"
33 #include "tensorflow/core/grappler/graph_view.h"
34 #include "tensorflow/core/grappler/utils/grappler_test.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/lib/random/random.h"
37 #include "tensorflow/core/util/util.h"
38 
39 // TODO(benbarsdell): Improve the numerical checks in these tests. The tests
40 // were originally written only to check the graph coloring, so the graphs do
41 // not have particularly realistic numerical behavior.
42 
43 namespace tensorflow {
44 namespace grappler {
45 namespace {
46 
47 template <DataType DTYPE>
GenerateIdentityMatrix(int64_t height,int64_t width)48 Tensor GenerateIdentityMatrix(int64_t height, int64_t width) {
49   typedef typename EnumToDataType<DTYPE>::Type T;
50   Tensor tensor(DTYPE, TensorShape{height, width});
51   for (int64_t i = 0; i < height; ++i) {
52     for (int64_t j = 0; j < width; ++j) {
53       tensor.matrix<T>()(i, j) = i == j;
54     }
55   }
56   return tensor;
57 }
58 
59 template <DataType DTYPE>
GenerateRandomTensorInRange(const TensorShape & shape,double minval,double maxval)60 Tensor GenerateRandomTensorInRange(const TensorShape& shape, double minval,
61                                    double maxval) {
62   typedef typename EnumToDataType<DTYPE>::Type T;
63   Tensor tensor(DTYPE, shape);
64   for (auto i = 0; i < tensor.NumElements(); i++)
65     tensor.flat<T>()(i) =
66         (random::New64() % 65536 / 65536.0) * (maxval - minval) + minval;
67   return tensor;
68 }
69 
VerifyGraphsEquivalent(const GraphDef & original_graph,const GraphDef & optimized_graph,const string & func)70 void VerifyGraphsEquivalent(const GraphDef& original_graph,
71                             const GraphDef& optimized_graph,
72                             const string& func) {
73   EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
74   GraphView optimized_view(&optimized_graph);
75   for (int i = 0; i < original_graph.node_size(); ++i) {
76     const NodeDef& original = original_graph.node(i);
77     const NodeDef& optimized = *optimized_view.GetNode(original.name());
78     EXPECT_EQ(original.name(), optimized.name()) << func;
79     EXPECT_EQ(original.op(), optimized.op()) << func;
80     EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
81     if (original.input_size() == optimized.input_size()) {
82       for (int j = 0; j < original.input_size(); ++j) {
83         EXPECT_EQ(original.input(j), optimized.input(j)) << func;
84       }
85     }
86   }
87 }
88 
89 // Currently, this test suite only passes when TensorFlow passes with CUDA/HIP,
90 // because otherwise the optimizer will not turn clearlist nodes to float16.
91 // When looking at clearlist nodes, this optimizer checks if the nodes have a
92 // float16 GPU OpKernel, but without CUDA/HIP there are no GPU OpKernels at all.
93 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
94 
// Minimum GPU architecture {major, minor} that real GPUs must satisfy for
// these tests (checked against GetNumAvailableGPUs in SetUp).
const std::pair<int, int> kMinGPUArch = {7, 0};
96 
97 class AutoMixedPrecisionTest : public GrapplerTest {
98  protected:
SetUp()99   void SetUp() override {
100     int num_gpus = GetNumAvailableGPUs();
101     // If GPUs are available, require that they all satisfy the min arch.
102     gpu_available_ = (num_gpus > 0);
103 #if GOOGLE_CUDA
104     gpu_available_ =
105         gpu_available_ && (num_gpus == GetNumAvailableGPUs(kMinGPUArch));
106 #else  // Here we force Tensorflow to use the virtual GFX906
107     gpu_available_ = false;
108 #endif
109     if (gpu_available_) {
110       virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 1));
111     } else {
112       DeviceProperties device_properties;
113       device_properties.set_type("GPU");
114 #if GOOGLE_CUDA
115       device_properties.mutable_environment()->insert({"architecture", "7"});
116       device_properties.mutable_environment()->insert({"cuda", "9010"});
117 #else
118       device_properties.mutable_environment()->insert(
119           {"architecture", "gfx906"});
120 #endif
121       virtual_cluster_.reset(
122           new VirtualCluster({{"/GPU:1", device_properties}}));
123     }
124     TF_CHECK_OK(virtual_cluster_->Provision());
125   }
126 
TearDown()127   void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }
128 
AddSimpleNode(const string & name,const string & op,const std::vector<string> & inputs,GraphDef * graph) const129   NodeDef* AddSimpleNode(const string& name, const string& op,
130                          const std::vector<string>& inputs,
131                          GraphDef* graph) const {
132     std::vector<std::pair<string, AttrValue>> attributes;
133     if (op == "AddN" || op == "ShapeN") {
134       AttrValue num_inputs;
135       num_inputs.set_i(inputs.size());
136       attributes.emplace_back("N", num_inputs);
137     }
138     if (op == "ShapeN") {
139       AttrValue out_type;
140       out_type.set_type(DT_INT32);
141       attributes.emplace_back("out_type", out_type);
142     }
143     AttrValue type;
144     type.set_type(DT_FLOAT);
145     if (op == "Const" || op == "Placeholder" || op == "VariableV2" ||
146         op == "VarHandleOp" || op == "ReadVariableOp") {
147       attributes.emplace_back("dtype", type);
148     } else if (op == "SparseMatMul") {
149       attributes.emplace_back("Ta", type);
150       attributes.emplace_back("Tb", type);
151     } else if (op == "IdentityN") {
152       AttrValue type_list;
153       for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
154         type_list.mutable_list()->add_type(DT_FLOAT);
155       }
156       attributes.emplace_back("T", type_list);
157     } else if (op == "StackV2" || op == "StackPopV2") {
158       attributes.emplace_back("elem_type", type);
159     } else if (op == "Cast") {
160       attributes.emplace_back("SrcT", type);
161       attributes.emplace_back("DstT", type);
162     } else {
163       attributes.emplace_back("T", type);
164     }
165     return AddNode(name, op, inputs, attributes, graph);
166   }
167 
TestSimpleUnaryInferOp(double input_min,double input_max,double atol,double rtol,const std::function<Output (const tensorflow::Scope &,Output)> & test_op_factory)168   void TestSimpleUnaryInferOp(
169       double input_min, double input_max, double atol, double rtol,
170       const std::function<Output(const tensorflow::Scope&, Output)>&
171           test_op_factory) {
172     int size = 128;
173     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
174     Output eye = ops::Const(s.WithOpName("eye"),
175                             GenerateIdentityMatrix<DT_FLOAT>(size, size));
176     Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
177     Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, eye);
178     Output infer1 = test_op_factory(s.WithOpName("infer1"), allow1);
179     Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, eye);
180     Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
181     GrapplerItem item;
182     item.fetch = {"fetch1"};
183     TF_CHECK_OK(s.ToGraphDef(&item.graph));
184     auto input_tensor = GenerateRandomTensorInRange<DT_FLOAT>(
185         TensorShape({size, size}), input_min, input_max);
186     std::vector<std::pair<string, Tensor>> feed = {{"input", input_tensor}};
187     auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
188 
189     AutoMixedPrecision optimizer;
190     GraphDef output;
191     TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
192 
193     VLOG(1) << output.DebugString();
194 
195     GraphView output_view(&output);
196     EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(),
197               DT_FLOAT);
198     EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
199     EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
200     EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
201 
202     auto tensors = EvaluateNodes(output, item.fetch, feed);
203     EXPECT_EQ(tensors.size(), tensors_expected.size());
204     EXPECT_EQ(tensors.size(), item.fetch.size());
205     for (int i = 0; i < item.fetch.size(); ++i) {
206       test::ExpectClose(tensors_expected[i], tensors[i], atol, rtol);
207     }
208   }
209 
210   std::unique_ptr<Cluster> virtual_cluster_;
211   bool gpu_available_;
212 };
213 
// A graph containing no allow-list ops must be left untouched: every node
// keeps DT_FLOAT and the node set is unchanged.
TEST_F(AutoMixedPrecisionTest, NoOp) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.234f, {32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);

  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t matches std::vector::size() and avoids a signed/unsigned compare.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
250 
// A graph where the user already casts to fp16 around the MatMul: the
// optimizer keeps the existing Casts, paints the region between them half,
// and adds no new nodes (graphs stay structurally equivalent).
TEST_F(AutoMixedPrecisionTest, AlreadyFp16) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_HALF);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
  Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  VLOG(1) << output.DebugString();

  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t matches std::vector::size() and avoids a signed/unsigned compare.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
288 
// Basic coloring: the allow node (MatMul) and the clear nodes adjacent to it
// (clr2, clr3) become DT_HALF; deny nodes (Exp, SparseMatMul) and everything
// adjacent to them stay DT_FLOAT. Two Cast nodes are inserted at the
// half/float boundary.
TEST_F(AutoMixedPrecisionTest, Simple) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
  Output infer2 = ops::Log(s.WithOpName("infer2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), infer2);
  Output deny2 = ops::SparseMatMul(s.WithOpName("deny2"), clr4, clr4);
  Output clr5 = ops::Relu(s.WithOpName("clr5"), deny2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Two inserted Casts account for the extra nodes.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Ta").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Tb").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t matches std::vector::size() and avoids a signed/unsigned compare.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
337 
// With TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL=TREAT_INFER_AS_DENY, infer
// ops behave like deny ops: infer2 and infer3 remain DT_FLOAT even though
// they sit next to half MatMuls, while clear ops adjacent to the MatMuls
// (clr2, clr3, clr4) are still converted.
TEST_F(AutoMixedPrecisionTest, NoInferOp) {
  setenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "TREAT_INFER_AS_DENY",
         1 /* replace */);

  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
  Output infer2 = ops::Log(s.WithOpName("infer2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), infer2);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), clr4, clr4);
  Output infer3 = ops::Log(s.WithOpName("infer3"), allow2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), infer3);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  // NOTE(review): if this assertion fails, the function returns early and the
  // env var above stays set for subsequent tests in this process.
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 4);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer3")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t matches std::vector::size() and avoids a signed/unsigned compare.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
  unsetenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL");
}
389 
// Half precision must propagate through clear ops in BOTH directions: only
// clr1 feeds the MatMul, but clr3 (ShapeN) links clr1 and clr2, so clr2 and
// its consumer clr4 are painted DT_HALF as well.
TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
  auto clr3 = ops::ShapeN(s.WithOpName("clr3"), {clr1, clr2});
  Output clr4 = ops::Relu(s.WithOpName("clr4"), clr2);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), clr4);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 3);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t matches std::vector::size() and avoids a signed/unsigned compare.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
428 
// Nodes that appear in item.fetch must keep their original DT_FLOAT output
// type even when they would otherwise be converted (allow1 and clr3 stay
// float here); non-fetched allow2 is still converted to half.
TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output deny1 = ops::Exp(s.WithOpName("deny1"), infer1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), deny1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow2);
  Output deny2 = ops::Exp(s.WithOpName("deny2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), deny2);

  GrapplerItem item;
  item.fetch = {"allow1", "clr2", "clr3"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t matches std::vector::size() and avoids a signed/unsigned compare.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-3);
  }
}
473 
// Nodes explicitly placed on a CPU device must not be converted: allow2 is a
// MatMul pinned to CPU and stays DT_FLOAT, while the unpinned allow1 region
// becomes DT_HALF.
TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
  Output allow2 =
      ops::MatMul(s.WithOpName("allow2").WithDevice(
                      "/job:localhost/replica:0/task:0/device:CPU:0"),
                  infer1, infer1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), allow2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t matches std::vector::size() and avoids a signed/unsigned compare.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
514 
// An Identity fed by a Variable must keep DT_FLOAT (clr1), while an Identity
// fed by a Const may be converted (clr2 becomes DT_HALF); both downstream
// MatMuls are still converted.
TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output var1 = ops::Variable(s.WithOpName("var1"), {32, 32}, DT_FLOAT);
  Output clr1 = ops::Identity(s.WithOpName("clr1"), var1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, clr1);
  Output input2 = ops::Const(s.WithOpName("input2"), 1.f / 32, {32, 32});
  Output clr2 = ops::Identity(s.WithOpName("clr2"), input2);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), input, clr2);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // The Variable has no initializer here, so feed its value directly.
  auto var1_tensor =
      GenerateConstantTensor<DT_FLOAT>(TensorShape({32, 32}), 3.141593f);
  std::vector<std::pair<string, Tensor>> feed = {{"var1", var1_tensor}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 5);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("var1")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("input2")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t matches std::vector::size() and avoids a signed/unsigned compare.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-3);
  }
}
558 
TEST_F(AutoMixedPrecisionTest,FusedBatchNorm)559 TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
560   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
561   // Uses NHWC data format because non-GPU execution does not support NCHW.
562   Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {8, 56, 56, 16});
563   Output weight = ops::Const(s.WithOpName("weight"), 2.f, {3, 3, 16, 16});
564   Output scale = ops::Const(s.WithOpName("scale"), 3.f, {16});
565   Output offset = ops::Const(s.WithOpName("offset"), 4.f, {16});
566   Output mean = ops::Const(s.WithOpName("mean"), 5.f, {0});
567   Output variance = ops::Const(s.WithOpName("variance"), 6.f, {0});
568   Output allow1 =
569       ops::Conv2D(s.WithOpName("allow1"), input, weight, {1, 1, 1, 1}, "SAME",
570                   ops::Conv2D::DataFormat("NHWC"));
571   auto fbn1_op =
572       ops::FusedBatchNorm(s.WithOpName("fbn1"), allow1, scale, offset, mean,
573                           variance, ops::FusedBatchNorm::DataFormat("NHWC"));
574   Output fbn1 = fbn1_op.y;
575   Output fbn1_rs1 = fbn1_op.reserve_space_1;
576   Output fbn1_rs2 = fbn1_op.reserve_space_2;
577   Output bng1 = ops::FusedBatchNormGrad(
578                     s.WithOpName("bng1"), fbn1, allow1, scale, fbn1_rs1,
579                     fbn1_rs2, ops::FusedBatchNormGrad::DataFormat("NHWC"))
580                     .x_backprop;
581   Output infer1 = ops::Add(s.WithOpName("infer1"), fbn1, bng1);
582   Output allow2 =
583       ops::Conv2D(s.WithOpName("allow2"), infer1, weight, {1, 1, 1, 1}, "SAME",
584                   ops::Conv2D::DataFormat("NHWC"));
585   Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);
586 
587   GrapplerItem item;
588   item.fetch = {"fetch"};
589   TF_CHECK_OK(s.ToGraphDef(&item.graph));
590   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
591 
592   AutoMixedPrecision optimizer;
593   GraphDef output;
594   TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
595 
596   VLOG(1) << output.DebugString();
597 
598   GraphView output_view(&output);
599   EXPECT_EQ(output.node_size(), item.graph.node_size() + 3);
600   EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
601   EXPECT_EQ(output_view.GetNode("fbn1")->op(), "FusedBatchNormV2");
602   EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("T").type(), DT_HALF);
603   EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("U").type(), DT_FLOAT);
604   EXPECT_EQ(output_view.GetNode("bng1")->op(), "FusedBatchNormGradV2");
605   EXPECT_EQ(output_view.GetNode("bng1")->attr().at("T").type(), DT_HALF);
606   EXPECT_EQ(output_view.GetNode("bng1")->attr().at("U").type(), DT_FLOAT);
607   EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
608   EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
609 
610   auto tensors = EvaluateNodes(output, item.fetch);
611   EXPECT_EQ(tensors.size(), tensors_expected.size());
612   EXPECT_EQ(tensors.size(), item.fetch.size());
613   for (int i = 0; i < item.fetch.size(); ++i) {
614     test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-2);
615   }
616 }
617 
// Ops with list-type attrs (IdentityN) and N-repeated inputs (AddN) must have
// every entry of their type attr converted to DT_HALF, not just the first.
TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto clr1_op = ops::IdentityN(s.WithOpName("clr1"), {allow1, allow1, allow1});
  Output infer1 =
      ops::AddN(s.WithOpName("infer1"),
                {clr1_op.output[0], clr1_op.output[1], clr1_op.output[2]});
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  // Every element of IdentityN's type list must have been rewritten.
  for (auto type : output_view.GetNode("clr1")->attr().at("T").list().type()) {
    EXPECT_EQ(type, DT_HALF);
  }
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  // size_t matches std::vector::size() and avoids a signed/unsigned compare.
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
657 
TEST_F(AutoMixedPrecisionTest, ExistingCast) {
  // A pre-existing bool->float Cast feeding an allow-list MatMul should be
  // retargeted to emit DT_HALF directly (DstT changes) rather than having an
  // extra float->half Cast inserted after it.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(scope.WithOpName("input"), true, {32, 32});
  Output cst1 = ops::Cast(scope.WithOpName("cst1"), input, DT_FLOAT);
  Output allow1 = ops::MatMul(scope.WithOpName("allow1"), cst1, cst1);
  Output fetch = ops::Identity(scope.WithOpName("fetch"), allow1);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Only one node is expected to be added to the graph.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 1);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("SrcT").type(), DT_BOOL);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
689 
// Builds a while-loop frame (Enter/Merge/Switch/NextIteration/Exit) whose
// recurrent edge connects nodes that would be painted different colors, and
// checks the optimizer resolves the conflict by keeping the loop in FP32.
// The loop cannot be expressed directly with the C++ op builders, so the
// GraphDef is patched after construction (see NodeMap edits below).
TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output ent1 =
      ops::internal::Enter(s.WithOpName("ent1"), deny1, "loop1").output;
  // Note that the second input is later replaced with "nxt1".
  Output mrg1 = ops::Merge(s.WithOpName("mrg1"), {ent1, ent1}).output;
  // For simplicity, the loop condition is constant false.
  Output con1 = ops::Const(s.WithOpName("con1"), false, {});
  Output lpc1 = ops::LoopCond(s.WithOpName("lpc1"), con1).output;
  auto swt1 = ops::Switch(s.WithOpName("swt1"), mrg1, lpc1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), swt1.output_true);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), infer1, infer1);
  Output nxt1 = ops::NextIteration(s.WithOpName("nxt1"), allow1);
  Output ext1 = ops::internal::Exit(s.WithOpName("ext1"), swt1.output_false);
  Output fetch = ops::Identity(s.WithOpName("fetch"), ext1);
  // Add a second merge node from the same NextIteration node. This case arises
  // during graph optimization of some models.
  auto mrg2 = ops::Merge(s.WithOpName("mrg2"), {ent1, nxt1});

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Patch the serialized graph: the builders above cannot create a cycle.
  NodeMap node_map_original(&item.graph);
  auto merge_node = node_map_original.GetNode("mrg1");
  // Modify the graph to create a loop.
  merge_node->set_input(1, "nxt1");
  // Add a control edge to ensure the loop condition is inside the frame.
  auto const_node = node_map_original.GetNode("con1");
  const_node->add_input("^mrg1");
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  // Note that mrg1 gets painted deny because it is between deny1 and infer1.
  // This forces nxt1 and mrg2 to be painted deny as well (they would otherwise
  // be painted allow because they are clear and have a direct path to allow1).
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("ent1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("mrg1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("swt1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("nxt1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("ext1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("mrg2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
751 
TEST_F(AutoMixedPrecisionTest,TensorListSetGet)752 TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
753   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
754   tensorflow::Input shape = {32, 32};
755   auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
756   Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
757   Output idx1 = ops::Const(s.WithOpName("idx1"), 1);
758   Output idx2 = ops::Const(s.WithOpName("idx2"), 2);
759   Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
760   auto tl1w1 =
761       ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
762   Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
763   auto tl1w2 =
764       ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
765   // Ensure that TensorListResize doesn't cause any problems.
766   Output tl1rs =
767       ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
768   Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
769                                         shape, DT_FLOAT)
770                      .item;
771   Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
772   Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
773   auto tl1w3 =
774       ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
775   Output tl1r2 =
776       ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
777                              shape, DT_FLOAT)
778           .item;
779   auto tl2 = ops::TensorListReserve(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
780   auto tl2w1 =
781       ops::TensorListSetItem(s.WithOpName("tl2w1"), tl2.handle, idx1, input);
782   Output tl2r1 =
783       ops::TensorListGetItem(s.WithOpName("tl2r1"), tl2w1.output_handle, idx1,
784                              shape, DT_FLOAT)
785           .item;
786   Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
787   Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);
788 
789   GrapplerItem item;
790   item.fetch = {"fetch1", "fetch2"};
791   TF_CHECK_OK(s.ToGraphDef(&item.graph));
792   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
793 
794   AutoMixedPrecision optimizer;
795   GraphDef output;
796   TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
797 
798   VLOG(1) << output.DebugString();
799 
800   GraphView output_view(&output);
801   EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
802   const char* type_key = "element_dtype";
803   EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
804   EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
805   EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
806   EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
807   EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
808   EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
809   EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
810   EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
811   EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
812   EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
813   EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);
814 
815   auto tensors = EvaluateNodes(output, item.fetch);
816   EXPECT_EQ(tensors.size(), tensors_expected.size());
817   EXPECT_EQ(tensors.size(), item.fetch.size());
818   for (int i = 0; i < item.fetch.size(); ++i) {
819     test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
820   }
821 }
822 
TEST_F(AutoMixedPrecisionTest,TensorListPushPop)823 TEST_F(AutoMixedPrecisionTest, TensorListPushPop) {
824   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
825   tensorflow::Input shape = {32, 32};
826   auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
827   Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
828   auto tl1w1 =
829       ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, input);
830   Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
831   auto tl1w2 = ops::TensorListPushBack(s.WithOpName("tl1w2"),
832                                        tl1w1.output_handle, allow1);
833   Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"),
834                                         tl1w2.output_handle, shape, DT_FLOAT)
835                      .tensor;
836   Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
837   Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
838   auto tl1w3 =
839       ops::TensorListPushBack(s.WithOpName("tl1w3"), tl1.handle, allow2);
840   Output tl1r2 = ops::TensorListPopBack(s.WithOpName("tl1r2"),
841                                         tl1w3.output_handle, shape, DT_FLOAT)
842                      .tensor;
843   auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
844   auto tl2w1 =
845       ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, input);
846   Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
847                                         tl2w1.output_handle, shape, DT_FLOAT)
848                      .tensor;
849   Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
850   Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);
851 
852   GrapplerItem item;
853   item.fetch = {"fetch1", "fetch2"};
854   TF_CHECK_OK(s.ToGraphDef(&item.graph));
855   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
856 
857   AutoMixedPrecision optimizer;
858   GraphDef output;
859   TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
860 
861   VLOG(1) << output.DebugString();
862 
863   GraphView output_view(&output);
864   EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
865   const char* type_key = "element_dtype";
866   EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
867   EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
868   EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
869   EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
870   EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
871   EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
872   EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
873   EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
874   EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
875   EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
876   EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);
877 
878   auto tensors = EvaluateNodes(output, item.fetch);
879   EXPECT_EQ(tensors.size(), tensors_expected.size());
880   EXPECT_EQ(tensors.size(), item.fetch.size());
881   for (int i = 0; i < item.fetch.size(); ++i) {
882     test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
883   }
884 }
885 
TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) {
  // Lists created from an allow-painted tensor must become DT_HALF, and that
  // paint must propagate to otherwise-unpainted client ops of the list.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32};
  Output input = ops::Const(scope.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(scope.WithOpName("allow1"), input, input);
  auto tl1 = ops::TensorListFromTensor(scope.WithOpName("tl1"), allow1, shape);
  Output tl1r1 = ops::TensorListStack(scope.WithOpName("tl1r1"),
                                      tl1.output_handle, shape, DT_FLOAT)
                     .tensor;
  Output infer1 = ops::Tanh(scope.WithOpName("infer1"), tl1r1);
  Output allow2 = ops::MatMul(scope.WithOpName("allow2"), infer1, infer1);
  Output fetch1 = ops::Identity(scope.WithOpName("fetch1"), allow2);

  // This tests that a allow-painted object node (tl2) will force an unpainted
  // client node (tl2w1) to be painted allow as well. (Without the force, tl2w1
  // would remain unpainted, producing an invalid graph).
  auto tl2 = ops::TensorListFromTensor(scope.WithOpName("tl2"), allow1, shape);
  auto tl2w1 = ops::TensorListPushBack(scope.WithOpName("tl2w1"),
                                       tl2.output_handle, input);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 2e-4);
  }
}
935 
TEST_F(AutoMixedPrecisionTest,TensorListPushBackBatchAndConcatLists)936 TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
937   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
938   tensorflow::Input shape = {32, 32};
939   auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
940   auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
941   Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
942   Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
943   Output tl1_tl2 =
944       ops::Stack(s.WithOpName("tl1_tl2"), {tl1.handle, tl2.handle});
945   Output allow1_allow1 =
946       ops::Stack(s.WithOpName("allow1_allow1"), {allow1, allow1});
947   auto tl12w1 = ops::TensorListPushBackBatch(s.WithOpName("tl12w1"), tl1_tl2,
948                                              allow1_allow1);
949   OutputList tl12w1_outputs =
950       ops::Split(s.WithOpName("tl12w1_outputs"), 0, tl12w1.output_handles, 2)
951           .output;
952   Output scalar_shape = ops::Const(s.WithOpName("scalar_shape"), 0, {0});
953   Output tl12w1_output0 = ops::Reshape(s.WithOpName("tl12w1_output0"),
954                                        tl12w1_outputs[0], scalar_shape);
955   Output tl12w1_output1 = ops::Reshape(s.WithOpName("tl12w1_output1"),
956                                        tl12w1_outputs[1], scalar_shape);
957   Output tl3 = ops::TensorListConcatLists(s.WithOpName("tl3"), tl12w1_output0,
958                                           tl12w1_output1, DT_FLOAT);
959   Output tl3r1 =
960       ops::TensorListPopBack(s.WithOpName("tl3r1"), tl3, shape, DT_FLOAT)
961           .tensor;
962   Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl3r1);
963   Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
964   Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
965 
966   GrapplerItem item;
967   item.fetch = {"fetch1"};
968   TF_CHECK_OK(s.ToGraphDef(&item.graph));
969   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
970 
971   AutoMixedPrecision optimizer;
972   GraphDef output;
973   TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
974 
975   VLOG(1) << output.DebugString();
976 
977   GraphView output_view(&output);
978   EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
979   const char* type_key = "element_dtype";
980   EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
981   EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
982   EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
983   EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
984   EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
985   EXPECT_EQ(output_view.GetNode("tl3")->attr().at(type_key).type(), DT_HALF);
986   EXPECT_EQ(output_view.GetNode("tl3r1")->attr().at(type_key).type(), DT_HALF);
987 
988   auto tensors = EvaluateNodes(output, item.fetch);
989   EXPECT_EQ(tensors.size(), tensors_expected.size());
990   EXPECT_EQ(tensors.size(), item.fetch.size());
991   for (int i = 0; i < item.fetch.size(); ++i) {
992     test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
993   }
994 }
995 
TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
  // This test passes a tensor list handle through a function with its own
  // Tensor List ops inside to test that the types are not changed to a
  // conflicting state.
  // A separate Tensor List cluster is added to test that it is still changed to
  // DT_HALF.
  FunctionDefLibrary function_lib;
  const Tensor kShape = test::AsTensor<int32>({32, 32});
  // Func1(ihandle, x) pushes x onto the list, pops it back off, and returns
  // the resulting handle and element. All list ops inside use DT_FLOAT.
  FunctionDef func1 = FunctionDefHelper::Define(
      "Func1", {"ihandle: variant", "x: float"},
      {"ohandle: variant", "y: float"}, {},
      {
          {{"tl1w1_handle"},
           "TensorListPushBack",
           {"ihandle", "x"},
           {{"element_dtype", DT_FLOAT}}},
          {{"shape"}, "Const", {}, {{"value", kShape}, {"dtype", DT_INT32}}},
          {{"tl1r1_handle", "tl1r1_data"},
           "TensorListPopBack",
           {"tl1w1_handle", "shape"},
           {{"element_dtype", DT_FLOAT}}},
          {{"ohandle"}, "Identity", {"tl1r1_handle"}, {{"T", DT_VARIANT}}},
          {{"y"}, "Identity", {"tl1r1_data"}, {{"T", DT_FLOAT}}},
      });
  function_lib.add_function()->Swap(&func1);

  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib));
  tensorflow::Input shape = {32, 32};
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
  auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  auto tl1w1 =
      ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, infer1);
  // Function calls have no C++ op wrapper, so build the "Func1" node by hand
  // with NodeBuilder and feed it the list handle plus a float tensor.
  auto _infer1 = tensorflow::ops::AsNodeOut(s, infer1);
  auto _tl1w1_handle = tensorflow::ops::AsNodeOut(s, tl1w1.output_handle);
  auto builder =
      tensorflow::NodeBuilder("Func1", "Func1", s.graph()->op_registry());
  tensorflow::Node* func1_op;
  TF_CHECK_OK(builder.Input(_tl1w1_handle)
                  .Input(_infer1)
                  .Finalize(s.graph(), &func1_op));
  Output func1_handle(func1_op, 0);
  Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"), func1_handle,
                                        shape, DT_FLOAT)
                     .tensor;
  // Second list cluster, not routed through the function.
  auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, infer1);
  Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
                                        tl2w1.output_handle, shape, DT_FLOAT)
                     .tensor;
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), tl1r1, tl2r1);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  const char* type_key = "element_dtype";
  // The tl1 cluster that flows through Func1 is deliberately not checked
  // here; only the function-free tl2 cluster must be converted to half.
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
1079 
GetCudaVersion(const Cluster & cluster)1080 int GetCudaVersion(const Cluster& cluster) {
1081   auto devices = cluster.GetDevices();
1082   for (const auto& device : devices) {
1083     const DeviceProperties& device_properties = device.second;
1084     if (device_properties.type() == "GPU") {
1085       const auto& device_env = device_properties.environment();
1086       auto it = device_env.find("cuda");
1087       if (it != device_env.end()) {
1088         string cuda_version_str = it->second;
1089         return std::stoi(cuda_version_str);
1090       }
1091     }
1092   }
1093   return 0;
1094 }
1095 
// Returns true if the cluster's GPU is expected to support the half-precision
// rewrite: in CUDA builds this requires CUDA >= 9.1 (version code 9010); in
// non-CUDA builds (e.g. ROCm) every device is treated as supported.
// NOTE(review): the file guard uses `#if GOOGLE_CUDA` while this uses
// `#ifdef GOOGLE_CUDA` — these differ if the macro is ever defined to 0;
// confirm GOOGLE_CUDA is only defined when CUDA is actually enabled.
bool IsSupportedGPU(const Cluster& cluster) {
#ifdef GOOGLE_CUDA
  return GetCudaVersion(cluster) >= 9010;
#else
  return true;
#endif
}
1103 
TEST_F(AutoMixedPrecisionTest, BatchMatMul) {
  // BatchMatMul is only converted when the GPU qualifies per IsSupportedGPU;
  // otherwise the graph must be left completely unchanged.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(scope.WithOpName("input"), 1.f / 33, {64, 32, 32});
  Output allow1 = ops::BatchMatMul(scope.WithOpName("allow1"), input, input);
  Output fetch1 = ops::Identity(scope.WithOpName("fetch1"), allow1);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  if (IsSupportedGPU(*virtual_cluster_.get())) {
    // Converted: two extra nodes and a half-precision BatchMatMul.
    EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  } else {
    // Unsupported device: graph untouched.
    EXPECT_EQ(output.node_size(), item.graph.node_size());
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
  }

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 3.0e-3);
  }
}
1138 
TEST_F(AutoMixedPrecisionTest, EluOp) {
  // Elu over inputs in [-5, 5]; both tolerance arguments are 1e-3
  // (see TestSimpleUnaryInferOp for their exact meaning).
  const auto build_op = [](const tensorflow::Scope& scope, Output x) -> Output {
    return ops::Elu(scope, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, 1.0e-3, build_op);
}
1146 
TEST_F(AutoMixedPrecisionTest, ErfOp) {
  // Erf over inputs in [-5, 5]; second tolerance disabled (-1).
  const auto build_op = [](const tensorflow::Scope& scope, Output x) -> Output {
    return ops::Erf(scope, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, build_op);
}
1154 
TEST_F(AutoMixedPrecisionTest, ErfcOp) {
  // Erfc over inputs in [-5, 5]; second tolerance disabled (-1).
  const auto build_op = [](const tensorflow::Scope& scope, Output x) -> Output {
    return ops::Erfc(scope, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, build_op);
}
1162 
TEST_F(AutoMixedPrecisionTest, InvOp) {
  // Inv over strictly positive inputs [0.01, 10]; first tolerance disabled.
  const auto build_op = [](const tensorflow::Scope& scope, Output x) -> Output {
    return ops::Inv(scope, x);
  };
  TestSimpleUnaryInferOp(0.01, 10, -1, 1.0e-3, build_op);
}
1170 
TEST_F(AutoMixedPrecisionTest, LogOp) {
  // Log over strictly positive inputs [0.01, 10].
  const auto build_op = [](const tensorflow::Scope& scope, Output x) -> Output {
    return ops::Log(scope, x);
  };
  TestSimpleUnaryInferOp(0.01, 10, 1.0e-3, 2.0e-3, build_op);
}
1178 
TEST_F(AutoMixedPrecisionTest, Log1pOp) {
  // Log1p over inputs in (-1, 9]; the lower bound stays above -1 so the
  // argument of log(1 + x) remains positive.
  const auto build_op = [](const tensorflow::Scope& scope, Output x) -> Output {
    return ops::Log1p(scope, x);
  };
  TestSimpleUnaryInferOp(-0.99, 9, 1.0e-3, 5.0e-3, build_op);
}
1186 
TEST_F(AutoMixedPrecisionTest, LogSoftmaxOp) {
  // LogSoftmax over inputs in [-8, 8]; first tolerance disabled.
  const auto build_op = [](const tensorflow::Scope& scope, Output x) -> Output {
    return ops::LogSoftmax(scope, x);
  };
  TestSimpleUnaryInferOp(-8, 8, -1, 1.0e-2, build_op);
}
1194 
TEST_F(AutoMixedPrecisionTest, ReciprocalOp) {
  // Reciprocal over strictly positive inputs [0.01, 10].
  const auto build_op = [](const tensorflow::Scope& scope, Output x) -> Output {
    return ops::Reciprocal(scope, x);
  };
  TestSimpleUnaryInferOp(0.01, 10, -1, 1.0e-3, build_op);
}
1202 
TEST_F(AutoMixedPrecisionTest, SigmoidOp) {
  // Sigmoid over inputs in [-5, 5]; second tolerance disabled.
  const auto build_op = [](const tensorflow::Scope& scope, Output x) -> Output {
    return ops::Sigmoid(scope, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, build_op);
}
1210 
TEST_F(AutoMixedPrecisionTest, SoftmaxOp) {
  // Softmax over inputs in [-8, 8]; second tolerance disabled.
  const auto build_op = [](const tensorflow::Scope& scope, Output x) -> Output {
    return ops::Softmax(scope, x);
  };
  TestSimpleUnaryInferOp(-8, 8, 2.0e-3, -1, build_op);
}
1218 
TEST_F(AutoMixedPrecisionTest, SoftplusOp) {
  // Softplus over inputs in [-5, 5].
  const auto build_op = [](const tensorflow::Scope& scope, Output x) -> Output {
    return ops::Softplus(scope, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, 1.0e-3, build_op);
}
1226 
TEST_F(AutoMixedPrecisionTest, SqrtOp) {
  // Sqrt over non-negative inputs [0, 10].
  const auto build_op = [](const tensorflow::Scope& scope, Output x) -> Output {
    return ops::Sqrt(scope, x);
  };
  TestSimpleUnaryInferOp(0, 10, 1.0e-3, 1.0e-3, build_op);
}
1234 
TEST_F(AutoMixedPrecisionTest, TanhOp) {
  // Tanh over inputs in [-5, 5]; second tolerance disabled.
  const auto build_op = [](const tensorflow::Scope& scope, Output x) -> Output {
    return ops::Tanh(scope, x);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, build_op);
}
1242 
1243 class AutoMixedPrecisionCpuTest : public GrapplerTest {
1244  protected:
SetUp()1245   void SetUp() override {
1246     virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0));
1247     TF_CHECK_OK(virtual_cluster_->Provision());
1248   }
TearDown()1249   void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }
1250 
1251   std::unique_ptr<Cluster> virtual_cluster_;
1252 };
1253 
// CPU mode: the allow-list MatMul stays FP32, but the pass still inserts
// Cast pairs around it (nine Casts total for this graph) so neighbouring
// clusters can run in half precision.
TEST_F(AutoMixedPrecisionCpuTest, Simple) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
  Output infer2 = ops::Log(s.WithOpName("infer2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), infer2);
  Output deny2 = ops::SparseMatMul(s.WithOpName("deny2"), clr4, clr4);
  Output clr5 = ops::Relu(s.WithOpName("clr5"), deny2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // NOTE(review): tensors_expected is computed but never compared against the
  // optimized graph's outputs below — consider adding the numerical check.
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::CPU};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  const int expected_cast_ops = 9;
  EXPECT_EQ(output.node_size(), item.graph.node_size() + expected_cast_ops);

  GraphView output_view(&output);
  // Matmul is a FP32 op now
  auto matmul_op = output_view.GetNode("allow1");
  EXPECT_EQ(matmul_op->attr().at("T").type(), DT_FLOAT);
  // Every data input of the MatMul must come through a half->float Cast ...
  for (auto edge : output_view.GetFaninEdges(*matmul_op, false)) {
    EXPECT_EQ(edge.src.node->op(), "Cast");
    EXPECT_EQ(edge.src.node->attr().at("SrcT").type(), DT_HALF);
    EXPECT_EQ(edge.src.node->attr().at("DstT").type(), DT_FLOAT);
  }
  // ... and every data output must feed a float->half Cast.
  for (auto edge : output_view.GetFanoutEdges(*matmul_op, false)) {
    EXPECT_EQ(edge.dst.node->op(), "Cast");
    EXPECT_EQ(edge.dst.node->attr().at("SrcT").type(), DT_FLOAT);
    EXPECT_EQ(edge.dst.node->attr().at("DstT").type(), DT_HALF);
  }
}
1299 
TEST_F(AutoMixedPrecisionCpuTest, MixedFanout) {
  // Test when an FP16 allowed node has a mixed fanout of FP16 allowed node and
  // FP32 node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output input1 = ops::Const(s.WithOpName("input1"), 1.f / 32, {32, 32});
  Output input2 = ops::Const(s.WithOpName("input2"), 2.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input1, input2);
  // allow1 fans out to both an allowed MatMul (allow2) and a denied Exp.
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), allow1, input2);
  Output deny = ops::Exp(s.WithOpName("deny"), allow1);
  Output infer = ops::Add(s.WithOpName("infer"), deny, allow2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), infer);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::CPU};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  // Exactly this many Cast nodes should be inserted.
  const int expected_cast_ops = 10;
  EXPECT_EQ(output.node_size(), item.graph.node_size() + expected_cast_ops);

  GraphView output_view(&output);
  auto allow1_op = output_view.GetNode("allow1");
  // allow1 is fed through HALF -> FLOAT casts...
  for (auto edge : output_view.GetFaninEdges(*allow1_op, false)) {
    EXPECT_EQ(edge.src.node->op(), "Cast");
    EXPECT_EQ(edge.src.node->attr().at("SrcT").type(), DT_HALF);
    EXPECT_EQ(edge.src.node->attr().at("DstT").type(), DT_FLOAT);
  }
  // ...and its outputs are re-cast FLOAT -> HALF for downstream consumers.
  for (auto edge : output_view.GetFanoutEdges(*allow1_op, false)) {
    EXPECT_EQ(edge.dst.node->op(), "Cast");
    EXPECT_EQ(edge.dst.node->attr().at("SrcT").type(), DT_FLOAT);
    EXPECT_EQ(edge.dst.node->attr().at("DstT").type(), DT_HALF);
  }
  auto deny_op = output_view.GetNode("deny");
  // The denied Exp also reads through HALF -> FLOAT casts...
  for (auto edge : output_view.GetFaninEdges(*deny_op, false)) {
    EXPECT_EQ(edge.src.node->op(), "Cast");
    EXPECT_EQ(edge.src.node->attr().at("SrcT").type(), DT_HALF);
    EXPECT_EQ(edge.src.node->attr().at("DstT").type(), DT_FLOAT);
  }
  // ...but its FP32 output must NOT be followed by a Cast.
  for (auto edge : output_view.GetFanoutEdges(*deny_op, false)) {
    EXPECT_NE(edge.dst.node->op(), "Cast");
  }
}
1349 
// Fixture for exercising the GPU-targeted pass on a machine with no GPU.
// Individual tests may set TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU
// to make the pass behave as if a GPU were present; TearDown always clears it.
class AutoMixedPrecisionSimulateGpuTest : public GrapplerTest {
 protected:
  // Provisions a CPU-only virtual cluster.
  void SetUp() override {
    std::unordered_map<string, DeviceProperties> devices;
    DeviceProperties cpu_device;
    cpu_device.set_type("CPU");
    cpu_device.set_frequency(1000);
    cpu_device.set_num_cores(4);
    cpu_device.set_memory_size(1024 * 1024);
    devices["/job:localhost/replica:0/task:0/device:CPU:0"] = cpu_device;
    // Explicitly creating machine without GPU.
    virtual_cluster_.reset(new VirtualCluster(devices));
    TF_CHECK_OK(virtual_cluster_->Provision());
  }
  // Unsets the simulation env var so state never leaks between tests, then
  // shuts the cluster down.
  void TearDown() override {
    unsetenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU");
    TF_CHECK_OK(virtual_cluster_->Shutdown());
  }

  std::unique_ptr<Cluster> virtual_cluster_;

  // Builds the same deny/clear/infer/allow chain as the GPU "Simple" test and
  // runs the default-mode optimizer on the GPU-less cluster. When
  // `is_optimized` is true, the allowed section (clr2 -> allow1 -> clr3) is
  // expected to flip to DT_HALF and two Cast nodes to appear; otherwise the
  // graph must come out unchanged.
  void TestSimple(tensorflow::Scope s, bool is_optimized) {
    Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
    Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
    Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
    Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
    Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
    Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
    Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
    Output infer2 = ops::Log(s.WithOpName("infer2"), clr3);
    Output clr4 = ops::Relu(s.WithOpName("clr4"), infer2);
    Output deny2 = ops::SparseMatMul(s.WithOpName("deny2"), clr4, clr4);
    Output clr5 = ops::Relu(s.WithOpName("clr5"), deny2);
    Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);

    GrapplerItem item;
    item.fetch = {"fetch"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

    GraphDef output;
    AutoMixedPrecision optimizer;
    TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

    VLOG(1) << output.DebugString();

    GraphView output_view(&output);
    DataType expected_data_type = is_optimized ? DT_HALF : DT_FLOAT;
    // Two Cast nodes are added around the allowed section when optimized.
    int expected_graph_size =
        is_optimized ? item.graph.node_size() + 2 : item.graph.node_size();

    EXPECT_EQ(output.node_size(), expected_graph_size);
    EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(),
              DT_FLOAT);
    EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
    EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
    EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
    EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(),
              expected_data_type);
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(),
              expected_data_type);
    EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(),
              expected_data_type);
    EXPECT_EQ(output_view.GetNode("infer2")->attr().at("T").type(), DT_FLOAT);
    EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
    EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Ta").type(), DT_FLOAT);
    EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Tb").type(), DT_FLOAT);
    EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);

    // Numerical sanity check: optimized and original graphs must agree to
    // within half-precision tolerance.
    auto tensors = EvaluateNodes(output, item.fetch);
    EXPECT_EQ(tensors.size(), tensors_expected.size());
    EXPECT_EQ(tensors.size(), item.fetch.size());
    for (int i = 0; i < item.fetch.size(); ++i) {
      test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
    }
  }
};
1427 
// With no GPU and no simulation env var set, the pass must be a no-op.
TEST_F(AutoMixedPrecisionSimulateGpuTest, Simple_NoGpu) {
  tensorflow::Scope root_scope = tensorflow::Scope::NewRootScope();
  TestSimple(root_scope, /* is_optimized= */ false);
}
1431 
// With the simulation env var set, the pass should optimize even though the
// cluster has no physical GPU.
TEST_F(AutoMixedPrecisionSimulateGpuTest, Simple_SimulatedGpu) {
  const int kOverwrite = 1;
  setenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU", "true",
         kOverwrite);
  tensorflow::Scope root_scope = tensorflow::Scope::NewRootScope();
  TestSimple(root_scope, /* is_optimized= */ true);
}
1437 
// Even with GPU simulation enabled, nodes explicitly pinned to a CPU device
// must be left unoptimized.
TEST_F(AutoMixedPrecisionSimulateGpuTest, Simple_SimulatedGpu_CpuScope) {
  const int kOverwrite = 1;
  setenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU", "true",
         kOverwrite);
  tensorflow::Scope cpu_scope = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  TestSimple(cpu_scope, /* is_optimized= */ false);
}
1445 
1446 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1447 
1448 #if INTEL_MKL
1449 
// Fixture for the MKL (bfloat16) auto-mixed-precision tests; provisions a
// real single machine rather than a virtual cluster.
class AutoMixedPrecisionMklTest : public GrapplerTest {
 protected:
  void SetUp() override {
    // 1 CPU, 0 GPUs.
    virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0));
    TF_CHECK_OK(virtual_cluster_->Provision());
  }
  void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }

  std::unique_ptr<Cluster> virtual_cluster_;
};
1460 
// A graph that already contains an explicit float -> bfloat16 -> float cast
// pair around the allowed section must pass through the BF16 pass untouched
// (no nodes added or removed, all dtypes preserved).
TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_BFLOAT16);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
  Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::BF16};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  VLOG(1) << output.DebugString();

  // Structure unchanged; dtypes of every node exactly as authored.
  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
1499 
// BF16-mode coloring of a deny/clear/infer/allow chain: only the section
// around the allowed MatMul (clr2 -> allow1 -> clr3) flips to bfloat16, with
// two Cast nodes added at its boundary; everything else stays float.
TEST_F(AutoMixedPrecisionMklTest, Simple) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
  Output deny2 = ops::Log(s.WithOpName("deny2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), deny2);
  Output deny3 = ops::SparseMatMul(s.WithOpName("deny3"), clr4, clr4);
  Output clr5 = ops::Relu(s.WithOpName("clr5"), deny3);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::BF16};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Exactly two Cast nodes are added around the bfloat16 section.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Ta").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Tb").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);

  // Numerics must stay close to the full-precision run.
  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
1549 
// TensorList ops must be painted consistently with their producers/consumers:
// list tl1 is written from and read into allowed MatMuls, so its
// element_dtype (and that of all its set/get ops) flips to bfloat16, while
// list tl2 only touches float nodes and must stay float throughout.
TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output idx1 = ops::Const(s.WithOpName("idx1"), 1);
  Output idx2 = ops::Const(s.WithOpName("idx2"), 2);
  Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
  auto tl1w1 =
      ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1w2 =
      ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
  // Ensure that TensorListResize doesn't cause any problems.
  Output tl1rs =
      ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
  Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
                                        shape, DT_FLOAT)
                     .item;
  Output infer1 = ops::Mul(s.WithOpName("infer1"), tl1r1, tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  auto tl1w3 =
      ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
  Output tl1r2 =
      ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
                             shape, DT_FLOAT)
          .item;
  // Second, independent list that never touches an allowed op.
  auto tl2 = ops::TensorListReserve(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListSetItem(s.WithOpName("tl2w1"), tl2.handle, idx1, input);
  Output tl2r1 =
      ops::TensorListGetItem(s.WithOpName("tl2r1"), tl2w1.output_handle, idx1,
                             shape, DT_FLOAT)
          .item;
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::BF16};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  // Whole tl1 chain goes bfloat16...
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(),
            DT_BFLOAT16);
  // ...while the tl2 chain stays float.
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-2);
  }
}
1626 
// An infer op (BiasAdd) that directly follows an allowed op (Conv2D) should
// be converted to bfloat16 along with the ops downstream of it (Relu).
TEST_F(AutoMixedPrecisionMklTest, InferFollowUpStreamAllow) {
  if (!IsMKLEnabled())
    GTEST_SKIP() << "Test only applicable to MKL auto-mixed precision.";
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output input1 = ops::Const(s.WithOpName("input1"), 1.f / 32, {8, 56, 56, 16});
  Output weight = ops::Const(s.WithOpName("weight"), 2.f, {3, 3, 16, 16});
  Output allow =
      ops::Conv2D(s.WithOpName("allow"), input1, weight, {1, 1, 1, 1}, "SAME",
                  ops::Conv2D::DataFormat("NHWC"));
  Output input2 = ops::Const(s.WithOpName("input2"), 1.f / 32, {16});
  Output infer = ops::BiasAdd(s.WithOpName("infer"), allow, input2);
  Output clr = ops::Relu(s.WithOpName("clr"), infer);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::BF16};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Four Cast nodes are expected to be inserted.
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 4);
  EXPECT_EQ(output_view.GetNode("input1")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("weight")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("input2")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("infer")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("clr")->attr().at("T").type(), DT_BFLOAT16);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-2);
  }
}
1669 
// An infer op (BiasAdd) that directly follows a denied op (Pow) must stay in
// float, as must everything downstream — no nodes added, no dtype changes.
TEST_F(AutoMixedPrecisionMklTest, InferFollowUpStreamDeny) {
  if (!IsMKLEnabled())
    GTEST_SKIP() << "Test only applicable to MKL auto-mixed precision.";
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output input1 = ops::Const(s.WithOpName("input1"), 1.f / 32, {8, 56, 56, 16});
  Output input2 = ops::Const(s.WithOpName("input2"), 1.f, {16});
  Output input3 = ops::Const(s.WithOpName("input3"), 1.f / 32, {16});
  Output deny = ops::Pow(s.WithOpName("deny"), input1, input2);
  Output infer = ops::BiasAdd(s.WithOpName("infer"), deny, input3);
  Output clr = ops::Relu(s.WithOpName("clr"), infer);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::BF16};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  // Graph must be completely unchanged.
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  EXPECT_EQ(output_view.GetNode("input1")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("input2")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("input3")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i]);
  }
}
1710 #endif  // INTEL_MKL
1711 
1712 }  // namespace
1713 }  // namespace grappler
1714 }  // namespace tensorflow
1715 
1716 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || INTEL_MKL
1717