1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_UTIL_H_ 16 #define TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_UTIL_H_ 17 18 #include <unordered_map> 19 20 #include "absl/types/optional.h" 21 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" 22 #include "tensorflow/compiler/tf2xla/xla_compiler.h" 23 #include "tensorflow/compiler/xla/status_macros.h" 24 #include "tensorflow/core/framework/graph.pb.h" 25 26 namespace tensorflow { 27 28 // Fills in xla_args from the corresponding _Arg nodes in the graph. 29 Status CreateXlaArgs(const Graph& graph, 30 std::vector<XlaCompiler::Argument>* xla_args); 31 32 // Populate xla_args for the given XLA config. 33 void PopulateXlaArgs(const tf2xla::Config& config, 34 std::vector<XlaCompiler::Argument>* xla_args); 35 36 // InitGraph creates a graph based on the graph_def, that may then be converted 37 // to an xla::XlaComputation via ConvertGraphToXla. 38 // 39 // The graph is rewritten with _Arg and _Retval nodes, representing the inputs 40 // and outputs of the function that will be compiled. Each feed id causes a new 41 // _Arg node to be created, where we first collect all existing edges pointing 42 // from the named node's output index, and then rewrite them to point from that 43 // _Arg node instead. Each fetch id causes a new _Retval node to be created, 44 // with a new edge pointing from the named node's output index to that _Retval 45 // node. 46 Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, 47 std::unique_ptr<Graph>* graph); 48 49 } // namespace tensorflow 50 51 #endif // TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_UTIL_H_ 52