• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_EAGER_TRANSFORM_GRAPH_FUNCTION_H_
16 #define TENSORFLOW_CORE_TFRT_EAGER_TRANSFORM_GRAPH_FUNCTION_H_
17 
18 #include <memory>
19 
20 #include "tensorflow/core/common_runtime/function_body.h"
21 #include "tfrt/support/forward_decls.h"  // from @tf_runtime
22 
23 namespace tfrt {
24 class Device;  // Forward declaration; only `const tfrt::Device*` is used below.
25 }  // namespace tfrt
26 
27 namespace tensorflow {
28 
29 class EagerContext;               // Forward declaration; used by pointer.
30 class FunctionDef;                // Forward declaration; used by reference.
31 class FunctionLibraryDefinition;  // Forward declaration; used by pointer.
32 
33 // Runs the placer; the outcome is reported through the returned Status.
34 // When `enable_grappler` is true, grappler passes are also run over the
35 // input function, which might add some entries to `func_lib_def`.
36 // NOTE(review): `fbody` looks like an output, `graph` a sink -- confirm.
37 // TODO(tfrt-devs): Consider passing in a more expressive compiler options
38 // object such as TFRTCompilerOptions instead of `enable_grappler`, for caller
39 // to configure graph transformation behavior, such as the more granular options
40 // in RewriterConfig proto and even individual grappler pass options like
41 // grappler::ArithmeticOptimizerOptions.
42 Status TransformGraphFunction(const std::string& func_name,
43                               const FunctionDef& fdef,
44                               const std::string& device_name,
45                               const tensorflow::DeviceSet& device_set,
46                               EagerContext* eager_ctx, bool enable_grappler,
47                               std::unique_ptr<FunctionBody>* fbody,
48                               std::unique_ptr<Graph> graph,
49                               tfrt::ArrayRef<const tfrt::Device*> input_devices,
50                               FunctionLibraryDefinition* func_lib_def);
51 
52 }  // namespace tensorflow
53 
54 #endif  // TENSORFLOW_CORE_TFRT_EAGER_TRANSFORM_GRAPH_FUNCTION_H_
55