• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_FLAGS_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_FLAGS_H_
18 
19 #include "absl/container/flat_hash_set.h"
20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/ADT/StringMap.h"
22 #include "tensorflow/core/framework/tensor_shape.pb.h"
23 #include "tensorflow/core/framework/types.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/lib/core/status.h"
26 
27 namespace tensorflow {
28 
29 struct ArrayInfo {
30   // The node type when the input node is imported. Typically needs to be
31   // specified when passing arbitrary nodes (some node attributes are removed).
32   DataType imported_dtype;
33 
34   // Node "shape" attribute value.
35   TensorShapeProto shape;
36 };
37 
38 struct GraphImportConfig {
39   using InputArrays =
40       llvm::MapVector<string, ArrayInfo, llvm::StringMap<unsigned>>;
41   // Maps input node names to node data types and shapes.
42   InputArrays inputs;
43   // name:index strings for the data outputs.
44   std::vector<string> outputs;
45   // name strings for the control outputs. This is currently only used when
46   // `graph_as_function` is set.
47   std::vector<string> control_outputs;
48   // Setting prune_unused_nodes to true, would prune unreachable nodes if
49   // output_arrays is specified.
50   bool prune_unused_nodes = false;
51   // If true, inputs of type LegacyFedInput are replaced with Placeholder ops.
52   // LegacyFedInput ops have two outputs unlike Placeholder which has only one
53   // output, so if both outputs of the LegacyFedInput ops are used then returns
54   // an error.
55   bool convert_legacy_fed_inputs = false;
56   // If true, the main graph will be treated as a function.
57   bool graph_as_function = false;
58   // If true, upgrade legacy features of the graph (for instance, functionalize
59   // control-flow).
60   bool upgrade_legacy = false;
61 };
62 
// Options that select which parts of the GraphDef are emitted on export.
struct GraphExportConfig {
  // Emit the shape attribute on the NodeDefs in the GraphDef.
  bool export_shapes{true};

  // Emit the library field of the GraphDef.
  bool export_library{true};

  // Emit the debug original node names in the GraphDef.
  bool export_debug_info{true};

  // Treat the main graph as a function.
  bool graph_as_function{false};
};
73 
74 // Parses the command line flag strings to the specification of nodes in
75 // the Graph.
76 Status ParseOutputArrayInfo(absl::string_view array_names,
77                             std::vector<string>* outputs);
78 
79 Status ParseOutputArrayInfo(const std::vector<string>& output_names,
80                             std::vector<string>* outputs);
81 
82 // Parses the command line flag strings to the specification of nodes in
83 // the Graph. `data_types` input string can be empty since the flag is optional.
84 Status ParseInputArrayInfo(absl::string_view array_names,
85                            absl::string_view data_types,
86                            absl::string_view shapes,
87                            GraphImportConfig::InputArrays* inputs);
88 
89 Status ParseInputArrayInfo(const std::vector<string>& node_names,
90                            const std::vector<string>& node_dtypes,
91                            const std::vector<std::vector<int>>& node_shapes,
92                            GraphImportConfig::InputArrays* inputs);
93 }  // namespace tensorflow
94 
95 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_FLAGS_H_
96