• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // This file creates a binary that can run any registered optimization pass.
16 // ./xla_gpu_opt  --input_file_path=/tmp/input.pbtxt
17 // --output_file_path=/tmp/output.pbtxt
18 // --optimization_pass=NameOfGraphOptimizationPass
19 
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
22 #include "tensorflow/core/framework/types.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/platform/init_main.h"
25 #include "tensorflow/core/protobuf/config.pb.h"
26 #include "tensorflow/core/util/command_line_flags.h"
27 #include "tensorflow/tools/optimization/optimization_pass_runner.h"
28 
29 namespace tensorflow {
30 namespace {
RealMain(int argc,char ** argv)31 Status RealMain(int argc, char** argv) {
32   string input_file_path;
33   string output_file_path;
34   string optimization_pass;
35 
36   const std::vector<Flag> flag_list = {
37       Flag("input_file_path", &input_file_path, "Location of the input graph."),
38       Flag("output_file_path", &output_file_path,
39            "Location to write the resulting graph."),
40       // For now only a single optimization pass can be run.
41       Flag("optimization_pass", &optimization_pass,
42            "Which optimization pass to run."),
43   };
44   if (!Flags::Parse(&argc, argv, flag_list)) {
45     return errors::FailedPrecondition("Invalid flags passed");
46   }
47   port::InitMain(argv[0], &argc, &argv);
48 
49   if (input_file_path.empty()) {
50     return errors::FailedPrecondition("input_file_path is a required flag.");
51   }
52   if (output_file_path.empty()) {
53     return errors::FailedPrecondition("output_file_path is a required flag.");
54   }
55   if (optimization_pass.empty()) {
56     return errors::FailedPrecondition("optimization_pass is a required flag.");
57   }
58 
59   GraphDef graphdef_input;
60   TF_RETURN_IF_ERROR(
61       ReadTextProto(Env::Default(), input_file_path, &graphdef_input));
62 
63   tensorflow::OptimizationPassRunner runner;
64 
65   // Most machines in our servers currently use 8 gpus. There is nothing special
66   // about this number and it can be decreased or increased to test other
67   // configurations.
68   TF_RETURN_IF_ERROR(runner.AddCpus(8));
69   TF_RETURN_IF_ERROR(runner.AddGpus(8));
70 
71   // This binary is used to test TF:XLA behavior, so turn on auto_jit.
72   TF_RETURN_IF_ERROR(runner.SetJitLevel(tensorflow::OptimizerOptions::ON_2));
73   GraphDef graphdef_output;
74   TF_RETURN_IF_ERROR(runner.Run(optimization_pass, std::move(graphdef_input),
75                                 &graphdef_output));
76   return WriteTextProto(Env::Default(), output_file_path, graphdef_output);
77 }
78 }  // namespace
79 }  // namespace tensorflow
80 
main(int argc,char ** argv)81 int main(int argc, char** argv) {
82   TF_CHECK_OK(tensorflow::RealMain(argc, argv));
83   return 0;
84 }
85