• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
17 
18 #include "tensorflow/compiler/jit/flags.h"
19 
20 namespace tensorflow {
21 
GetUserRequest(absl::optional<ConfigProto> config_proto)22 static ConfigProto::Experimental::MlirBridgeRollout GetUserRequest(
23     absl::optional<ConfigProto> config_proto) {
24   // TF1 graphs that do not override Sessions's ConfigProto and TF2 graphs
25   // can enable/disable the graph via tf_mlir_enable_mlir_bridge.
26   auto tf_mlir_enable_mlir_bridge =
27       GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
28   if (tf_mlir_enable_mlir_bridge !=
29       ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED) {
30     return tf_mlir_enable_mlir_bridge;
31   }
32 
33   // If a ConfigProto was not passed in, we can assume the caller is
34   // checking if TF2 graph should have the bridge enabled / disabled. In that
35   // case, we have already checked tf_mlir_enable_mlir_bridge so it is safe to
36   // return UNSPECIFIED here.
37   if (!config_proto.has_value()) {
38     return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
39   }
40 
41   // TF1 graphs that do override Session's ConfigProto and set
42   // ConfigProto's enable_mlir_bridge or mlir_bridge_rollout fields will not
43   // update tf_mlir_enable_mlir_bridge so check their values.
44 
45   // ConfigProto's enable_mlir_bridge defaults to false so only respect it
46   // when it is true.
47   if (config_proto.value().experimental().enable_mlir_bridge()) {
48     return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
49   }
50   return config_proto.value().experimental().mlir_bridge_rollout();
51 }
52 
GetMlirBridgeRolloutPolicy(const tensorflow::Graph & graph,absl::optional<ConfigProto> config_proto,bool uses_uninitialized_resource_args,bool record_stats)53 MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
54     const tensorflow::Graph& graph, absl::optional<ConfigProto> config_proto,
55     bool uses_uninitialized_resource_args, bool record_stats) {
56   switch (GetUserRequest(config_proto)) {
57     case ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED:
58       return MlirBridgeRolloutPolicy::kEnabledByUser;
59     case ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED:
60       return MlirBridgeRolloutPolicy::kDisabledByUser;
61     default:
62       // User did not explicitly enable or disable the bridge. For now, disable
63       // the bridge.
64       return MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis;
65   }
66 }
67 
68 }  // namespace tensorflow
69