/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || INTEL_MKL

#include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"

#include <utility>
#include <vector>

#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"

// TODO(benbarsdell): Improve the numerical checks in these tests. The tests
// were originally written only to check the graph coloring, so the graphs do
// not have particularly realistic numerical behavior.
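// Note on naming: throughout these tests, node names encode the category the
// optimizer is expected to assign. "allow*" nodes (e.g. MatMul, Conv2D) are
// expected to be converted to float16, "deny*" nodes (e.g. Exp, SparseMatMul)
// must stay float32, and "infer*"/"clr*" nodes (inferlist/clearlist ops) take
// whichever type their neighbors end up with.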

namespace tensorflow {
namespace grappler {
namespace {

template <DataType DTYPE>
Tensor GenerateIdentityMatrix(int64_t height, int64_t width) {
  typedef typename EnumToDataType<DTYPE>::Type T;
  Tensor tensor(DTYPE, TensorShape{height, width});
  for (int64_t i = 0; i < height; ++i) {
    for (int64_t j = 0; j < width; ++j) {
      tensor.matrix<T>()(i, j) = i == j;
    }
  }
  return tensor;
}

template <DataType DTYPE>
Tensor GenerateRandomTensorInRange(const TensorShape& shape, double minval,
                                   double maxval) {
  typedef typename EnumToDataType<DTYPE>::Type T;
  Tensor tensor(DTYPE, shape);
  for (auto i = 0; i < tensor.NumElements(); i++)
    tensor.flat<T>()(i) =
        (random::New64() % 65536 / 65536.0) * (maxval - minval) + minval;
  return tensor;
}

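// Checks that the optimized graph has the same nodes, ops, and inputs as the
// original. Attribute values are not compared here, so callers that expect an
// unchanged graph also verify the dtype attributes explicitly.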
void VerifyGraphsEquivalent(const GraphDef& original_graph,
                            const GraphDef& optimized_graph,
                            const string& func) {
  EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
  GraphView optimized_view(&optimized_graph);
  for (int i = 0; i < original_graph.node_size(); ++i) {
    const NodeDef& original = original_graph.node(i);
    const NodeDef& optimized = *optimized_view.GetNode(original.name());
    EXPECT_EQ(original.name(), optimized.name()) << func;
    EXPECT_EQ(original.op(), optimized.op()) << func;
    EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
    if (original.input_size() == optimized.input_size()) {
      for (int j = 0; j < original.input_size(); ++j) {
        EXPECT_EQ(original.input(j), optimized.input(j)) << func;
      }
    }
  }
}

// Currently, this test suite only passes when TensorFlow is built with
// CUDA/HIP, because otherwise the optimizer will not turn clearlist nodes to
// float16. When looking at clearlist nodes, this optimizer checks if the nodes
// have a float16 GPU OpKernel, but without CUDA/HIP there are no GPU OpKernels
// at all.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

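// Minimum GPU architecture (CUDA compute capability 7.0, i.e. Volta) that any
// physical GPU must satisfy in SetUp() below for the tests to run on real
// hardware; otherwise a virtual GPU cluster is used instead.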
const std::pair<int, int> kMinGPUArch = {7, 0};

class AutoMixedPrecisionTest : public GrapplerTest {
 protected:
  void SetUp() override {
    int num_gpus = GetNumAvailableGPUs();
    // If GPUs are available, require that they all satisfy the min arch.
    gpu_available_ = (num_gpus > 0);
#if GOOGLE_CUDA
    gpu_available_ =
        gpu_available_ && (num_gpus == GetNumAvailableGPUs(kMinGPUArch));
#else  // Here we force TensorFlow to use the virtual GFX906
    gpu_available_ = false;
#endif
    if (gpu_available_) {
      virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 1));
    } else {
      DeviceProperties device_properties;
      device_properties.set_type("GPU");
#if GOOGLE_CUDA
      device_properties.mutable_environment()->insert({"architecture", "7"});
      device_properties.mutable_environment()->insert({"cuda", "9010"});
#else
      device_properties.mutable_environment()->insert(
          {"architecture", "gfx906"});
#endif
      virtual_cluster_.reset(
          new VirtualCluster({{"/GPU:1", device_properties}}));
    }
    TF_CHECK_OK(virtual_cluster_->Provision());
  }

  void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }

  NodeDef* AddSimpleNode(const string& name, const string& op,
                         const std::vector<string>& inputs,
                         GraphDef* graph) const {
    std::vector<std::pair<string, AttrValue>> attributes;
    if (op == "AddN" || op == "ShapeN") {
      AttrValue num_inputs;
      num_inputs.set_i(inputs.size());
      attributes.emplace_back("N", num_inputs);
    }
    if (op == "ShapeN") {
      AttrValue out_type;
      out_type.set_type(DT_INT32);
      attributes.emplace_back("out_type", out_type);
    }
    AttrValue type;
    type.set_type(DT_FLOAT);
    if (op == "Const" || op == "Placeholder" || op == "VariableV2" ||
        op == "VarHandleOp" || op == "ReadVariableOp") {
      attributes.emplace_back("dtype", type);
    } else if (op == "SparseMatMul") {
      attributes.emplace_back("Ta", type);
      attributes.emplace_back("Tb", type);
    } else if (op == "IdentityN") {
      AttrValue type_list;
      for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
        type_list.mutable_list()->add_type(DT_FLOAT);
      }
      attributes.emplace_back("T", type_list);
    } else if (op == "StackV2" || op == "StackPopV2") {
      attributes.emplace_back("elem_type", type);
    } else if (op == "Cast") {
      attributes.emplace_back("SrcT", type);
      attributes.emplace_back("DstT", type);
    } else {
      attributes.emplace_back("T", type);
    }
    return AddNode(name, op, inputs, attributes, graph);
  }

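  // Builds a graph of the form MatMul -> <unary op under test> -> MatMul,
  // runs the optimizer, checks that the unary op is converted to float16
  // along with the surrounding MatMuls, and compares the numerical results of
  // the original and optimized graphs within the given tolerances.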
  void TestSimpleUnaryInferOp(
      double input_min, double input_max, double atol, double rtol,
      const std::function<Output(const tensorflow::Scope&, Output)>&
          test_op_factory) {
    int size = 128;
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output eye = ops::Const(s.WithOpName("eye"),
                            GenerateIdentityMatrix<DT_FLOAT>(size, size));
    Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
    Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, eye);
    Output infer1 = test_op_factory(s.WithOpName("infer1"), allow1);
    Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, eye);
    Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
    GrapplerItem item;
    item.fetch = {"fetch1"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    auto input_tensor = GenerateRandomTensorInRange<DT_FLOAT>(
        TensorShape({size, size}), input_min, input_max);
    std::vector<std::pair<string, Tensor>> feed = {{"input", input_tensor}};
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);

    AutoMixedPrecision optimizer;
    GraphDef output;
    TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

    VLOG(1) << output.DebugString();

    GraphView output_view(&output);
    EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(),
              DT_FLOAT);
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
    EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
    EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

    auto tensors = EvaluateNodes(output, item.fetch, feed);
    EXPECT_EQ(tensors.size(), tensors_expected.size());
    EXPECT_EQ(tensors.size(), item.fetch.size());
    for (int i = 0; i < item.fetch.size(); ++i) {
      test::ExpectClose(tensors_expected[i], tensors[i], atol, rtol);
    }
  }

  std::unique_ptr<Cluster> virtual_cluster_;
  bool gpu_available_;
};

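// The graph below contains no allowlist ops, so the optimizer should leave
// every node in float32 and the graph should be unchanged.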
TEST_F(AutoMixedPrecisionTest, NoOp) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.234f, {32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);

  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}

TEST_F(AutoMixedPrecisionTest, AlreadyFp16) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_HALF);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
  Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  VLOG(1) << output.DebugString();

  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}

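// A mixed chain of deny/clear/infer/allow ops. Only the cluster around the
// MatMul ("clr2" -> "allow1" -> "clr3") is expected to flip to float16; the
// two extra nodes in the optimized graph are the Cast ops inserted at the
// float32/float16 boundaries.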
TEST_F(AutoMixedPrecisionTest, Simple) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
  Output infer2 = ops::Log(s.WithOpName("infer2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), infer2);
  Output deny2 = ops::SparseMatMul(s.WithOpName("deny2"), clr4, clr4);
  Output clr5 = ops::Relu(s.WithOpName("clr5"), deny2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Ta").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Tb").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}

TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
  auto clr3 = ops::ShapeN(s.WithOpName("clr3"), {clr1, clr2});
  Output clr4 = ops::Relu(s.WithOpName("clr4"), clr2);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), clr4);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 3);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}

TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output deny1 = ops::Exp(s.WithOpName("deny1"), infer1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), deny1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow2);
  Output deny2 = ops::Exp(s.WithOpName("deny2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), deny2);

  GrapplerItem item;
  item.fetch = {"allow1", "clr2", "clr3"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-3);
  }
}

TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
  Output allow2 =
      ops::MatMul(s.WithOpName("allow2").WithDevice(
                      "/job:localhost/replica:0/task:0/device:CPU:0"),
                  infer1, infer1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), allow2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}

TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output var1 = ops::Variable(s.WithOpName("var1"), {32, 32}, DT_FLOAT);
  Output clr1 = ops::Identity(s.WithOpName("clr1"), var1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, clr1);
  Output input2 = ops::Const(s.WithOpName("input2"), 1.f / 32, {32, 32});
  Output clr2 = ops::Identity(s.WithOpName("clr2"), input2);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), input, clr2);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto var1_tensor =
      GenerateConstantTensor<DT_FLOAT>(TensorShape({32, 32}), 3.141593f);
  std::vector<std::pair<string, Tensor>> feed = {{"var1", var1_tensor}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 5);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("var1")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("input2")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-3);
  }
}

TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // Uses NHWC data format because non-GPU execution does not support NCHW.
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {8, 56, 56, 16});
  Output weight = ops::Const(s.WithOpName("weight"), 2.f, {3, 3, 16, 16});
  Output scale = ops::Const(s.WithOpName("scale"), 3.f, {16});
  Output offset = ops::Const(s.WithOpName("offset"), 4.f, {16});
  Output mean = ops::Const(s.WithOpName("mean"), 5.f, {0});
  Output variance = ops::Const(s.WithOpName("variance"), 6.f, {0});
  Output allow1 =
      ops::Conv2D(s.WithOpName("allow1"), input, weight, {1, 1, 1, 1}, "SAME",
                  ops::Conv2D::DataFormat("NHWC"));
  auto fbn1_op =
      ops::FusedBatchNorm(s.WithOpName("fbn1"), allow1, scale, offset, mean,
                          variance, ops::FusedBatchNorm::DataFormat("NHWC"));
  Output fbn1 = fbn1_op.y;
  Output fbn1_rs1 = fbn1_op.reserve_space_1;
  Output fbn1_rs2 = fbn1_op.reserve_space_2;
  Output bng1 = ops::FusedBatchNormGrad(
                    s.WithOpName("bng1"), fbn1, allow1, scale, fbn1_rs1,
                    fbn1_rs2, ops::FusedBatchNormGrad::DataFormat("NHWC"))
                    .x_backprop;
  Output infer1 = ops::Add(s.WithOpName("infer1"), fbn1, bng1);
  Output allow2 =
      ops::Conv2D(s.WithOpName("allow2"), infer1, weight, {1, 1, 1, 1}, "SAME",
                  ops::Conv2D::DataFormat("NHWC"));
  Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 3);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("fbn1")->op(), "FusedBatchNormV2");
  EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("U").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("bng1")->op(), "FusedBatchNormGradV2");
  EXPECT_EQ(output_view.GetNode("bng1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("bng1")->attr().at("U").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-2);
  }
}

TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto clr1_op = ops::IdentityN(s.WithOpName("clr1"), {allow1, allow1, allow1});
  Output infer1 =
      ops::AddN(s.WithOpName("infer1"),
                {clr1_op.output[0], clr1_op.output[1], clr1_op.output[2]});
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  for (auto type : output_view.GetNode("clr1")->attr().at("T").list().type()) {
    EXPECT_EQ(type, DT_HALF);
  }
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}

TEST_F(AutoMixedPrecisionTest, ExistingCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), true, {32, 32});
  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_FLOAT);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
  Output fetch = ops::Identity(s.WithOpName("fetch"), allow1);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 1);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("SrcT").type(), DT_BOOL);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}

TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output ent1 =
      ops::internal::Enter(s.WithOpName("ent1"), deny1, "loop1").output;
  // Note that the second input is later replaced with "nxt1".
  Output mrg1 = ops::Merge(s.WithOpName("mrg1"), {ent1, ent1}).output;
  // For simplicity, the loop condition is constant false.
  Output con1 = ops::Const(s.WithOpName("con1"), false, {});
  Output lpc1 = ops::LoopCond(s.WithOpName("lpc1"), con1).output;
  auto swt1 = ops::Switch(s.WithOpName("swt1"), mrg1, lpc1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), swt1.output_true);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), infer1, infer1);
  Output nxt1 = ops::NextIteration(s.WithOpName("nxt1"), allow1);
  Output ext1 = ops::internal::Exit(s.WithOpName("ext1"), swt1.output_false);
  Output fetch = ops::Identity(s.WithOpName("fetch"), ext1);
  // Add a second merge node from the same NextIteration node. This case arises
  // during graph optimization of some models.
  auto mrg2 = ops::Merge(s.WithOpName("mrg2"), {ent1, nxt1});

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  NodeMap node_map_original(&item.graph);
  auto merge_node = node_map_original.GetNode("mrg1");
  // Modify the graph to create a loop.
  merge_node->set_input(1, "nxt1");
  // Add a control edge to ensure the loop condition is inside the frame.
  auto const_node = node_map_original.GetNode("con1");
  const_node->add_input("^mrg1");
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  // Note that mrg1 gets painted deny because it is between deny1 and infer1.
  // This forces nxt1 and mrg2 to be painted deny as well (they would otherwise
  // be painted allow because they are clear and have a direct path to allow1).
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("ent1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("mrg1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("swt1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("nxt1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("ext1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("mrg2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}

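// The tensor list tests below check that a list's element_dtype is changed to
// DT_HALF only when the list is connected (through its handle) to allowlist
// ops, and that independent lists such as "tl2" in the next two tests stay in
// float32.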
TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output idx1 = ops::Const(s.WithOpName("idx1"), 1);
  Output idx2 = ops::Const(s.WithOpName("idx2"), 2);
  Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
  auto tl1w1 =
      ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1w2 =
      ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
  // Ensure that TensorListResize doesn't cause any problems.
  Output tl1rs =
      ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
  Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
                                        shape, DT_FLOAT)
                     .item;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  auto tl1w3 =
      ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
  Output tl1r2 =
      ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
                             shape, DT_FLOAT)
          .item;
  auto tl2 = ops::TensorListReserve(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListSetItem(s.WithOpName("tl2w1"), tl2.handle, idx1, input);
  Output tl2r1 =
      ops::TensorListGetItem(s.WithOpName("tl2r1"), tl2w1.output_handle, idx1,
                             shape, DT_FLOAT)
          .item;
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}

TEST_F(AutoMixedPrecisionTest, TensorListPushPop) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  auto tl1w1 =
      ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1w2 = ops::TensorListPushBack(s.WithOpName("tl1w2"),
                                       tl1w1.output_handle, allow1);
  Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"),
                                        tl1w2.output_handle, shape, DT_FLOAT)
                     .tensor;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  auto tl1w3 =
      ops::TensorListPushBack(s.WithOpName("tl1w3"), tl1.handle, allow2);
  Output tl1r2 = ops::TensorListPopBack(s.WithOpName("tl1r2"),
                                        tl1w3.output_handle, shape, DT_FLOAT)
                     .tensor;
  auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, input);
  Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
                                        tl2w1.output_handle, shape, DT_FLOAT)
                     .tensor;
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}

TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32};
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1 = ops::TensorListFromTensor(s.WithOpName("tl1"), allow1, shape);
  Output tl1r1 = ops::TensorListStack(s.WithOpName("tl1r1"), tl1.output_handle,
                                      shape, DT_FLOAT)
                     .tensor;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);

  // This tests that an allow-painted object node (tl2) will force an unpainted
  // client node (tl2w1) to be painted allow as well. (Without the force, tl2w1
  // would remain unpainted, producing an invalid graph).
  auto tl2 = ops::TensorListFromTensor(s.WithOpName("tl2"), allow1, shape);
  auto tl2w1 =
      ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.output_handle, input);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 2e-4);
  }
}

TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  Output tl1_tl2 =
      ops::Stack(s.WithOpName("tl1_tl2"), {tl1.handle, tl2.handle});
  Output allow1_allow1 =
      ops::Stack(s.WithOpName("allow1_allow1"), {allow1, allow1});
  auto tl12w1 = ops::TensorListPushBackBatch(s.WithOpName("tl12w1"), tl1_tl2,
                                             allow1_allow1);
  OutputList tl12w1_outputs =
      ops::Split(s.WithOpName("tl12w1_outputs"), 0, tl12w1.output_handles, 2)
          .output;
  Output scalar_shape = ops::Const(s.WithOpName("scalar_shape"), 0, {0});
  Output tl12w1_output0 = ops::Reshape(s.WithOpName("tl12w1_output0"),
                                       tl12w1_outputs[0], scalar_shape);
  Output tl12w1_output1 = ops::Reshape(s.WithOpName("tl12w1_output1"),
                                       tl12w1_outputs[1], scalar_shape);
  Output tl3 = ops::TensorListConcatLists(s.WithOpName("tl3"), tl12w1_output0,
                                          tl12w1_output1, DT_FLOAT);
  Output tl3r1 =
      ops::TensorListPopBack(s.WithOpName("tl3r1"), tl3, shape, DT_FLOAT)
          .tensor;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl3r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl3")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl3r1")->attr().at(type_key).type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}

TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
  // This test passes a tensor list handle through a function with its own
  // Tensor List ops inside to test that the types are not changed to a
  // conflicting state.
  // A separate Tensor List cluster is added to test that it is still changed
  // to DT_HALF.
  FunctionDefLibrary function_lib;
  const Tensor kShape = test::AsTensor<int32>({32, 32});
  FunctionDef func1 = FunctionDefHelper::Define(
      "Func1", {"ihandle: variant", "x: float"},
      {"ohandle: variant", "y: float"}, {},
      {
          {{"tl1w1_handle"},
           "TensorListPushBack",
           {"ihandle", "x"},
           {{"element_dtype", DT_FLOAT}}},
          {{"shape"}, "Const", {}, {{"value", kShape}, {"dtype", DT_INT32}}},
          {{"tl1r1_handle", "tl1r1_data"},
           "TensorListPopBack",
           {"tl1w1_handle", "shape"},
           {{"element_dtype", DT_FLOAT}}},
          {{"ohandle"}, "Identity", {"tl1r1_handle"}, {{"T", DT_VARIANT}}},
          {{"y"}, "Identity", {"tl1r1_data"}, {{"T", DT_FLOAT}}},
      });
  function_lib.add_function()->Swap(&func1);

  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib));
  tensorflow::Input shape = {32, 32};
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
  auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  auto tl1w1 =
      ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, infer1);
  auto _infer1 = tensorflow::ops::AsNodeOut(s, infer1);
  auto _tl1w1_handle = tensorflow::ops::AsNodeOut(s, tl1w1.output_handle);
  auto builder =
      tensorflow::NodeBuilder("Func1", "Func1", s.graph()->op_registry());
  tensorflow::Node* func1_op;
  TF_CHECK_OK(builder.Input(_tl1w1_handle)
                  .Input(_infer1)
                  .Finalize(s.graph(), &func1_op));
  Output func1_handle(func1_op, 0);
  Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"), func1_handle,
                                        shape, DT_FLOAT)
                     .tensor;
  auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, infer1);
  Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
                                        tl2w1.output_handle, shape, DT_FLOAT)
                     .tensor;
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), tl1r1, tl2r1);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}

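// Returns the CUDA runtime version advertised in the GPU device's
// "environment" map (e.g. 9010 for CUDA 9.1), or 0 if no GPU with a "cuda"
// entry is found.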
int GetCudaVersion(const Cluster& cluster) {
  auto devices = cluster.GetDevices();
  for (const auto& device : devices) {
    const DeviceProperties& device_properties = device.second;
    if (device_properties.type() == "GPU") {
      const auto& device_env = device_properties.environment();
      auto it = device_env.find("cuda");
      if (it != device_env.end()) {
        string cuda_version_str = it->second;
        return std::stoi(cuda_version_str);
      }
    }
  }
  return 0;
}

bool IsSupportedGPU(const Cluster& cluster) {
#ifdef GOOGLE_CUDA
  return GetCudaVersion(cluster) >= 9010;
#else
  return true;
#endif
}

TEST_F(AutoMixedPrecisionTest, BatchMatMul) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 33, {64, 32, 32});
  Output allow1 = ops::BatchMatMul(s.WithOpName("allow1"), input, input);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  if (IsSupportedGPU(*virtual_cluster_.get())) {
    EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  } else {
    EXPECT_EQ(output.node_size(), item.graph.node_size());
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
  }

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 3.0e-3);
  }
}

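// The tests below exercise TestSimpleUnaryInferOp for individual inferlist
// ops, with input ranges and tolerances chosen per op (e.g. strictly positive
// inputs for Log, Inv, and Reciprocal).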
TEST_F(AutoMixedPrecisionTest, EluOp) {
  TestSimpleUnaryInferOp(
      -5, 5, 1.0e-3, 1.0e-3,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Elu(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, ErfOp) {
  TestSimpleUnaryInferOp(
      -5, 5, 1.0e-3, -1,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Erf(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, ErfcOp) {
  TestSimpleUnaryInferOp(
      -5, 5, 1.0e-3, -1,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Erfc(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, InvOp) {
  TestSimpleUnaryInferOp(
      0.01, 10, -1, 1.0e-3,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Inv(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, LogOp) {
  TestSimpleUnaryInferOp(
      0.01, 10, 1.0e-3, 2.0e-3,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Log(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, Log1pOp) {
  TestSimpleUnaryInferOp(
      -0.99, 9, 1.0e-3, 5.0e-3,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Log1p(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, LogSoftmaxOp) {
  TestSimpleUnaryInferOp(
      -8, 8, -1, 1.0e-2,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::LogSoftmax(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, ReciprocalOp) {
  TestSimpleUnaryInferOp(
      0.01, 10, -1, 1.0e-3,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Reciprocal(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, SigmoidOp) {
  TestSimpleUnaryInferOp(
      -5, 5, 1.0e-3, -1,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Sigmoid(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, SoftmaxOp) {
  TestSimpleUnaryInferOp(
      -8, 8, 2.0e-3, -1,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Softmax(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, SoftplusOp) {
  TestSimpleUnaryInferOp(
      -5, 5, 1.0e-3, 1.0e-3,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Softplus(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, SqrtOp) {
  TestSimpleUnaryInferOp(
      0, 10, 1.0e-3, 1.0e-3,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Sqrt(scope, input);
      });
}

TEST_F(AutoMixedPrecisionTest, TanhOp) {
  TestSimpleUnaryInferOp(
      -5, 5, 1.0e-3, -1,
      [](const tensorflow::Scope& scope, Output input) -> Output {
        return ops::Tanh(scope, input);
      });
}

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if INTEL_MKL

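// The MKL-mode tests exercise the bfloat16 variant of the optimizer
// (AutoMixedPrecisionMode::MKL) on a real local CPU device provisioned via
// SingleMachine, rather than the cluster used by the GPU tests above.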
class AutoMixedPrecisionMklTest : public GrapplerTest {
 protected:
  void SetUp() override {
    virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0));
    TF_CHECK_OK(virtual_cluster_->Provision());
  }
  void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }

  std::unique_ptr<Cluster> virtual_cluster_;
};

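// A graph that already computes in bfloat16 via explicit Casts should be left
// structurally unchanged (see VerifyGraphsEquivalent below); only the existing
// dtype attributes are checked.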
TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_BFLOAT16);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
  Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  VLOG(1) << output.DebugString();

  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}

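// Basic op-coloring check for the MKL mode: the allow-list MatMul and the
// clear-list Relus adjacent to it become bfloat16, while deny-list ops (Exp,
// Log, SparseMatMul) and the nodes next to them stay float. The two extra
// nodes in the output graph are the Casts inserted around the bfloat16 region.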
TEST_F(AutoMixedPrecisionMklTest, Simple) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
  Output deny2 = ops::Log(s.WithOpName("deny2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), deny2);
  Output deny3 = ops::SparseMatMul(s.WithOpName("deny3"), clr4, clr4);
  Output clr5 = ops::Relu(s.WithOpName("clr5"), deny3);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Ta").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Tb").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}

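// TensorList handles whose reads and writes feed (or are fed by) bfloat16
// MatMuls should have their element_dtype rewritten to bfloat16, while the
// unrelated list tl2 keeps DT_FLOAT.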
TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");
  tensorflow::Input shape = {32, 32};
  auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output idx1 = ops::Const(s.WithOpName("idx1"), 1);
  Output idx2 = ops::Const(s.WithOpName("idx2"), 2);
  Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
  auto tl1w1 =
      ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1w2 =
      ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
  // Ensure that TensorListResize doesn't cause any problems.
  Output tl1rs =
      ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
  Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
                                        shape, DT_FLOAT)
                     .item;
  Output infer1 = ops::Mul(s.WithOpName("infer1"), tl1r1, tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  auto tl1w3 =
      ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
  Output tl1r2 =
      ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
                             shape, DT_FLOAT)
          .item;
  auto tl2 = ops::TensorListReserve(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
  auto tl2w1 =
      ops::TensorListSetItem(s.WithOpName("tl2w1"), tl2.handle, idx1, input);
  Output tl2r1 =
      ops::TensorListGetItem(s.WithOpName("tl2r1"), tl2w1.output_handle, idx1,
                             shape, DT_FLOAT)
          .item;
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(),
            DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  EXPECT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-2);
  }
}

#endif  // INTEL_MKL

}  // namespace
}  // namespace grappler
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || INTEL_MKL