• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_UTIL_H_
16 #define TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_UTIL_H_
17 
#include <memory>
#include <unordered_map>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/graph.pb.h"
25 
26 namespace tensorflow {
27 
28 // Fills in xla_args from the corresponding _Arg nodes in the graph.
29 Status CreateXlaArgs(const Graph& graph,
30                      std::vector<XlaCompiler::Argument>* xla_args);
31 
32 // Populate xla_args for the given XLA config.
33 void PopulateXlaArgs(const tf2xla::Config& config,
34                      std::vector<XlaCompiler::Argument>* xla_args);
35 
36 // InitGraph creates a graph based on the graph_def, that may then be converted
37 // to an xla::XlaComputation via ConvertGraphToXla.
38 //
39 // The graph is rewritten with _Arg and _Retval nodes, representing the inputs
40 // and outputs of the function that will be compiled.  Each feed id causes a new
41 // _Arg node to be created, where we first collect all existing edges pointing
42 // from the named node's output index, and then rewrite them to point from that
43 // _Arg node instead.  Each fetch id causes a new _Retval node to be created,
44 // with a new edge pointing from the named node's output index to that _Retval
45 // node.
46 Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
47                  std::unique_ptr<Graph>* graph);
48 
49 }  // namespace tensorflow
50 
51 #endif  // TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_UTIL_H_
52