• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
17 
18 #include "tensorflow/compiler/jit/cluster_scoping_pass.h"
19 #include "tensorflow/core/common_runtime/device_factory.h"
20 #include "tensorflow/core/lib/gtl/cleanup.h"
21 #include "tensorflow/core/public/session_options.h"
22 
23 namespace tensorflow {
MarkForCompilation(std::unique_ptr<Graph> * graph,FunctionLibraryDefinition * flib_def,MarkForCompilationPassTestHelper::Options options)24 /*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
25     std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
26     MarkForCompilationPassTestHelper::Options options) {
27   // Assign all unassigned nodes to the CPU device.
28   static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
29   for (Node* n : (*graph)->nodes()) {
30     if (n->assigned_device_name().empty()) {
31       n->set_assigned_device_name(kCpuDevice);
32     }
33   }
34 
35   SessionOptions session_options;
36   if (options.enable_global_jit) {
37     session_options.config.mutable_graph_options()
38         ->mutable_optimizer_options()
39         ->set_global_jit_level(OptimizerOptions::ON_2);
40   }
41 
42   // Call AddDevices to register the XLA devices.
43   //
44   // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to
45   // make this more direct, but probably not worth it solely for this test.
46   std::vector<std::unique_ptr<Device>> devices;
47   TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(session_options, "", &devices));
48 
49   GraphOptimizationPassOptions opt_options;
50   opt_options.graph = graph;
51   opt_options.session_options = &session_options;
52   opt_options.flib_def = flib_def;
53 
54   if (options.enable_cluster_scoping) {
55     ClusterScopingPass cluster_scoping_pass;
56     TF_RETURN_IF_ERROR(cluster_scoping_pass.Run(opt_options));
57   }
58 
59   MarkForCompilationPass mark_for_compilation_pass;
60   return mark_for_compilation_pass.RunForTest(
61       opt_options,
62       /*disable_deadness_analysis=*/options.disable_deadness_analysis,
63       /*deterministic_cluster_names=*/options.deterministic_cluster_names);
64 }
65 
MarkForCompilation(std::unique_ptr<Graph> * graph,MarkForCompilationPassTestHelper::Options options)66 /*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
67     std::unique_ptr<Graph>* graph,
68     MarkForCompilationPassTestHelper::Options options) {
69   FunctionDefLibrary flib;
70   FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
71   return MarkForCompilation(graph, &flib_def, options);
72 }
73 }  // namespace tensorflow
74