1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_ 18 19 #include "tensorflow/core/common_runtime/device.h" 20 #include "tensorflow/core/framework/function.h" 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/graph/graph.h" 23 #include "tensorflow/core/platform/env.h" 24 25 // TODO(skyewm): can this be combined with EvaluateConstantTensor? 26 27 namespace tensorflow { 28 29 // This generator type is used to generate a name for the newly folded node 30 // based on the node's old name. 31 using ConstantFoldNameGenerator = 32 std::function<string(Graph* graph, string old_name)>; 33 34 // Options specific to constant folding optimizations. 35 struct ConstantFoldingOptions { 36 // If "consider" is not a nullptr, then only constant fold a node "n" if 37 // consider(n) returns true. 38 std::function<bool(const Node*)> consider = nullptr; 39 // If shape_map is not a nullptr, it is a map from node n to a 40 // vector of the (potentially partially-known) shapes of its 41 // outputs. 42 const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map = 43 nullptr; // not owned 44 // The maximum size of each constant created during constant folding 45 // optimization. 46 int64_t max_constant_size_in_bytes = 10 * 1024 * 1024; 47 48 // A generator for the name suffix of constant folded nodes. A 49 // default id generator that monotonically increases is used if nullptr is 50 // passed. 51 ConstantFoldNameGenerator generate_new_name = nullptr; 52 }; 53 54 // Perform constant folding optimization on "graph". 55 // Looks for nodes in "graph" that can be completely evaluated statically, i.e., 56 // that are only dependent on constants. Evaluates those nodes on a CPU device 57 // and replaces those nodes with the result of the evaluation. 58 // "partition_device", if non-null, is the device where all the graph nodes are 59 // assumed to execute. 60 // Sets `was_mutated` to true if and only if "graph" has been mutated. 61 // The status is only set to a non-OK state if an unexpected error is hit 62 // running the graph. 63 Status ConstantFold(const ConstantFoldingOptions& opts, 64 FunctionLibraryRuntime* function_library, Env* env, 65 const Device* partition_device, Graph* graph, 66 bool* was_mutated); 67 68 } // namespace tensorflow 69 70 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_ 71