# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""smart_cond and related utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops


def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.

  If `pred` is a bool or has a constant value, we return either `true_fn()`
  or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable, or if `pred` is not
      a Tensor, a Python bool, or 1/0 (raised by `smart_constant_value`).
  """
  if not callable(true_fn):
    raise TypeError("`true_fn` must be callable.")
  if not callable(false_fn):
    raise TypeError("`false_fn` must be callable.")

  pred_value = smart_constant_value(pred)
  if pred_value is None:
    # The predicate is not statically known: build a graph-level conditional.
    return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,
                                 name=name)
  # The predicate is statically known: call only the selected branch.
  if pred_value:
    return true_fn()
  return false_fn()


def smart_constant_value(pred):
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  Args:
    pred: A scalar, either a Python bool, the Python values 1 or 0, or a
      tensor.

  Returns:
    For bools and 1/0, the corresponding Python bool. For tensors, the
    statically-evaluated value produced by `tensor_util.constant_value` or
    `TF_TryEvaluateConstant_wrapper` (not necessarily a Python bool), or None
    if the value could not be determined statically.

  Raises:
    TypeError: If `pred` is not a Tensor, a Python bool, or 1/0.
  """
  # Check for tensors first: the 1/0 membership test below hashes `pred`, and
  # tensors may be unhashable (TF2 equality semantics).
  if isinstance(pred, ops.Tensor):
    pred_value = tensor_util.constant_value(pred)
    # TODO(skyewm): consider folding this into tensor_util.constant_value.
    # pylint: disable=protected-access
    if pred_value is None:
      pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph,
                                                        pred._as_tf_output())
    # pylint: enable=protected-access
    return pred_value

  if isinstance(pred, bool):
    return pred

  # Accept 1/0 as valid boolean values. Unhashable inputs (lists, arrays, ...)
  # can never be 1/0, so report them with the documented error below instead of
  # letting the set lookup raise an unrelated "unhashable type" TypeError.
  try:
    is_zero_or_one = pred in {0, 1}
  except TypeError:
    is_zero_or_one = False
  if is_zero_or_one:
    return bool(pred)

  # Wrap `pred` in a 1-tuple so that tuple-valued inputs are formatted rather
  # than being interpreted as the %-format argument list.
  raise TypeError("`pred` must be a Tensor, or a Python bool, or 1 or 0. "
                  "Found instead: %s" % (pred,))


def smart_case(pred_fn_pairs, default=None, exclusive=False, name="smart_case"):
  """Like tf.case, except attempts to statically evaluate predicates.

  If any predicate in `pred_fn_pairs` is a bool or has a constant value, the
  associated callable will be called or omitted depending on its value.
  Otherwise this functions like tf.case.

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
      callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
               callable.
  """
  # Reuse tf.case's implementation, substituting `smart_cond` as the
  # conditional so statically-known predicates are resolved without building a
  # graph conditional. `allow_python_preds=True` presumably lets predicates be
  # Python bools rather than only tensors -- see `_case_helper` to confirm.
  return control_flow_ops._case_helper(  # pylint: disable=protected-access
      smart_cond, pred_fn_pairs, default, exclusive, name,
      allow_python_preds=True)