• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_
18 
19 #include "tensorflow/core/common_runtime/device.h"
20 #include "tensorflow/core/framework/function.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/graph/graph.h"
23 #include "tensorflow/core/platform/env.h"
24 
25 // TODO(skyewm): can this be combined with EvaluateConstantTensor?
26 
27 namespace tensorflow {
28 
29 // This generator type is used to generate a name for the newly folded node
30 // based on the node's old name.
31 using ConstantFoldNameGenerator =
32     std::function<string(Graph* graph, string old_name)>;
33 
34 // Options specific to constant folding optimizations.
35 struct ConstantFoldingOptions {
36   // If "consider" is not a nullptr, then only constant fold a node "n" if
37   // consider(n) returns true.
38   std::function<bool(const Node*)> consider = nullptr;
39   // If shape_map is not a nullptr, it is a map from node n to a
40   // vector of the (potentially partially-known) shapes of its
41   // outputs.
42   const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map =
43       nullptr;  // not owned
44   // The maximum size of each constant created during constant folding
45   // optimization.
46   int64 max_constant_size_in_bytes = 10 * 1024 * 1024;
47 
48   // A generator for the name suffix of constant folded nodes. A
49   // default id generator that monotonically increases is used if nullptr is
50   // passed.
51   ConstantFoldNameGenerator generate_new_name = nullptr;
52 };
53 
54 // Perform constant folding optimization on "graph".
55 // Looks for nodes in "graph" that can be completely evaluated statically, i.e.,
56 // that are only dependent on constants. Evaluates those nodes on a CPU device
57 // and replaces those nodes with the result of the evaluation.
58 // "partition_device", if non-null, is the device where all the graph nodes are
59 // assumed to execute.
60 // Sets `was_mutated` to true if and only if "graph" has been mutated.
61 // The status is only set to a non-OK state if an unexpected error is hit
62 // running the graph.
63 Status ConstantFold(const ConstantFoldingOptions& opts,
64                     FunctionLibraryRuntime* function_library, Env* env,
65                     const Device* partition_device, Graph* graph,
66                     bool* was_mutated);
67 
68 }  // namespace tensorflow
69 
70 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_
71