• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
17 
18 #include <regex>  // NOLINT
19 
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "tensorflow/cc/framework/ops.h"
23 #include "tensorflow/cc/framework/scope.h"
24 #include "tensorflow/cc/ops/standard_ops.h"
25 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
26 #include "tensorflow/compiler/tf2tensorrt/utils/trt_testutils.h"
27 #include "tensorflow/core/common_runtime/device_mgr.h"
28 #include "tensorflow/core/common_runtime/device_set.h"
29 #include "tensorflow/core/framework/tensor_shape.h"
30 #include "tensorflow/core/grappler/clusters/cluster.h"
31 #include "tensorflow/core/grappler/grappler_item.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35 #include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
36 #include "tensorflow/core/public/session.h"
37 
38 #if GOOGLE_CUDA && GOOGLE_TENSORRT
39 
40 namespace tensorflow {
41 namespace tensorrt {
42 namespace convert {
43 
// Minimal grappler::Cluster stub for tests: it only lets the test control the
// DeviceSet returned by GetDeviceSet(); every other virtual of the Cluster
// interface is an inert no-op that reports success.
class FakeCluster : public grappler::Cluster {
 public:
  FakeCluster() : Cluster(0) {}

  // Sets the device set returned by GetDeviceSet(). Does not take ownership;
  // `device_set` must outlive this object.
  void SetDeviceSet(const DeviceSet* device_set) { device_set_ = device_set; }

  const DeviceSet* GetDeviceSet() const override { return device_set_; }

  // Remaining grappler::Cluster overrides are trivial stubs.
  string type() const override { return ""; }
  Status Provision() override { return Status::OK(); }
  Status Initialize(const grappler::GrapplerItem& item) override {
    return Status::OK();
  }
  Status Run(const GraphDef& graph_def,
             const std::vector<std::pair<string, Tensor>>& feed,
             const std::vector<string>& fetch, RunMetadata* metadata) override {
    return Status::OK();
  }

 private:
  // Not owned; null until SetDeviceSet() is called.
  const DeviceSet* device_set_ = nullptr;
};
66 
TEST(GetDeviceAndAllocatorTest,GetDeviceAndAllocator)67 TEST(GetDeviceAndAllocatorTest, GetDeviceAndAllocator) {
68   TRTOptimizationPass::ConversionParams params;
69   EngineInfo engine_info;
70   {
71     // cluster is not set, and no gpu device is available.
72     auto result = GetDeviceAndAllocator(nullptr, engine_info);
73     EXPECT_EQ(-1, result.first);
74     EXPECT_EQ(nullptr, result.second);
75   }
76 
77   // Create a session with two (virtual) gpu device.
78   SessionOptions options;
79   ConfigProto* config = &options.config;
80   GPUOptions* gpu_options = config->mutable_gpu_options();
81   auto virtual_devices =
82       gpu_options->mutable_experimental()->add_virtual_devices();
83   virtual_devices->add_memory_limit_mb(200);
84   virtual_devices->add_memory_limit_mb(200);
85   std::unique_ptr<Session> session(NewSession(options));
86 
87   {
88     // cluster is not set, should find and return first gpu id and
89     // corresponding allocator.
90     auto result = GetDeviceAndAllocator(nullptr, engine_info);
91     EXPECT_EQ(0, result.first);
92     EXPECT_NE(nullptr, result.second);
93     EXPECT_EQ("GPU_0_bfc", result.second->Name());
94   }
95 
96   FakeCluster cluster;
97   {
98     // params.cluster->GetDeviceSet() returns null, should find and return first
99     // gpu id and corresponding allocator.
100     auto result = GetDeviceAndAllocator(&cluster, engine_info);
101     EXPECT_EQ(0, result.first);
102     EXPECT_NE(nullptr, result.second);
103     EXPECT_EQ("GPU_0_bfc", result.second->Name());
104   }
105 
106   // Build the DeviceSet.
107   DeviceSet device_set;
108   const DeviceMgr* device_mgr = nullptr;
109   TF_ASSERT_OK(session->LocalDeviceManager(&device_mgr));
110   for (auto d : device_mgr->ListDevices()) {
111     device_set.AddDevice(d);
112   }
113   cluster.SetDeviceSet(&device_set);
114   {
115     // engine_info.device is not set, should find and return first gpu id and
116     // corresponding allocator.
117     auto result = GetDeviceAndAllocator(&cluster, engine_info);
118     EXPECT_EQ(0, result.first);
119     EXPECT_NE(nullptr, result.second);
120     EXPECT_EQ("GPU_0_bfc", result.second->Name());
121   }
122 
123   engine_info.device = "/GPU:1";
124   {
125     // Set to use second device.
126     auto result = GetDeviceAndAllocator(&cluster, engine_info);
127     EXPECT_EQ(0, result.first);
128     EXPECT_NE(nullptr, result.second);
129     EXPECT_EQ("GPU_1_bfc", result.second->Name());
130   }
131 
132   engine_info.device = "/GPU:3";
133   {
134     // Set to use nonexistent device.
135     auto result = GetDeviceAndAllocator(&cluster, engine_info);
136     EXPECT_EQ(-1, result.first);
137     EXPECT_EQ(nullptr, result.second);
138   }
139 }
140 
141 class ConvertGraphTest : public ::testing::Test {
142  public:
RunConvertGraph(Scope s,GraphDef * output_graph_def,int maximum_batch_size=1000)143   Status RunConvertGraph(Scope s, GraphDef* output_graph_def,
144                          int maximum_batch_size = 1000) {
145     // Create GraphProperties.
146     grappler::GrapplerItem item;
147     TF_EXPECT_OK(s.ToGraphDef(&item.graph));
148     grappler::GraphProperties graph_properties(item);
149     TF_EXPECT_OK(graph_properties.InferStatically(true));
150 
151     // Construct ConversionParams.
152     const std::vector<string> input_output_names{"output"};
153     TRTOptimizationPass::ConversionParams params;
154     params.max_batch_size = maximum_batch_size;
155     params.max_workspace_size_bytes = 8 << 20;
156     params.minimum_segment_size = 1;
157     params.use_calibration = false;
158     params.trt_logger_name = "DefaultLogger";
159     return ConvertGraph(params, item, input_output_names, nullptr,
160                         output_graph_def);
161   }
162 };
163 
TEST_F(ConvertGraphTest, DirectlyConnectedEngines) {
  // Build a graph that converts into two TRTEngineOps where the upstream
  // engine feeds the downstream one through two connections that originate
  // from the same node:port inside the op. If the downstream TRTEngineOp is
  // added first, adding the upstream op must update that shared output
  // connection twice; this test guards the correctness of that path.
  Scope s = Scope::NewRootScope();
  auto feed = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                               ops::Placeholder::Shape({2, 1}));
  // The segment-root node names are chosen deliberately ("b" before "a") so
  // the downstream segment is processed first; updating the edge between the
  // two TRTEngineOps then attempts to add the same edge multiple times.
  auto downstream_root = ops::Identity(s.WithOpName("segment_root_b"), feed);
  auto sum1 = ops::Add(s.WithOpName("add1"), downstream_root, downstream_root);
  // Insert incompatible reshapes that change the batch dimension so the two
  // segments cannot be merged.
  auto reshaped = ops::Reshape(s.WithOpName("reshape1"), sum1, Input({1, 2}));
  reshaped = ops::Reshape(s.WithOpName("reshape2"), reshaped, Input({2, 1}));

  auto sum2 = ops::Add(s.WithOpName("add2"), reshaped, sum1);
  auto upstream_root = ops::Identity(s.WithOpName("segment_root_a"), sum1);
  auto sum3 = ops::Add(s.WithOpName("add3"), sum2, upstream_root);
  ops::Identity(s.WithOpName("output"), sum3);

  GraphDef output_graph_def;
  TF_EXPECT_OK(RunConvertGraph(s, &output_graph_def));

  // Engine names embed a per-graph sequence number; strip it so the
  // expectations below are stable across runs.
  auto strip_sequence_number = [](const std::string& name) {
    static const std::regex kPattern("TRTEngineOp_[0-9]+_");
    return std::regex_replace(name, kPattern, "TRTEngineOp_");
  };
  int trt_op_count = 0;
  for (const NodeDef& node : output_graph_def.node()) {
    if (node.op() != "TRTEngineOp") continue;
    const std::string canonical_name = strip_sequence_number(node.name());
    if (canonical_name == "TRTEngineOp_001") {
      // Upstream engine: fed directly by the placeholder.
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("input", node.input(0));
      ++trt_op_count;
    } else if (canonical_name == "TRTEngineOp_000") {
      // Downstream engine: fed by the upstream engine and the reshape chain.
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("TRTEngineOp_001", strip_sequence_number(node.input(0)));
      EXPECT_EQ("reshape2", node.input(1));
      ++trt_op_count;
    }
  }
  EXPECT_EQ(2, trt_op_count);
}
216 
217 }  // namespace convert
218 }  // namespace tensorrt
219 }  // namespace tensorflow
220 
221 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
222