# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


"""Utilities for using generic resources."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_should_use


# Record stored in the resource collections by register_resource():
#   handle: op returning a handle for the resource (its `.name` is reported),
#   create: op that initializes the resource,
#   is_initialized: op returning a scalar boolean "has been initialized".
_Resource = collections.namedtuple("_Resource",
                                   ["handle", "create", "is_initialized"])


def register_resource(handle, create_op, is_initialized_op, is_shared=True):
  """Registers a resource into the appropriate collections.

  This makes the resource findable in either the shared or local resources
  collection.

  Args:
    handle: op which returns a handle for the resource.
    create_op: op which initializes the resource.
    is_initialized_op: op which returns a scalar boolean tensor of whether
      the resource has been initialized.
    is_shared: if True, the resource gets added to the shared resource
      collection; otherwise it gets added to the local resource collection.
  """
  resource = _Resource(handle, create_op, is_initialized_op)
  if is_shared:
    ops.add_to_collection(ops.GraphKeys.RESOURCES, resource)
  else:
    ops.add_to_collection(ops.GraphKeys.LOCAL_RESOURCES, resource)


def shared_resources():
  """Returns resources visible to all tasks in the cluster."""
  return ops.get_collection(ops.GraphKeys.RESOURCES)


def local_resources():
  """Returns resources intended to be local to this session."""
  return ops.get_collection(ops.GraphKeys.LOCAL_RESOURCES)


def report_uninitialized_resources(resource_list=None,
                                   name="report_uninitialized_resources"):
  """Returns the names of all uninitialized resources in resource_list.

  If the returned tensor is empty then all resources have been initialized.

  Args:
    resource_list: resources to check; any iterable of registered resources
      is accepted. If None, will use shared_resources() + local_resources().
    name: name for the resource-checking op.

  Returns:
    Tensor containing names of the handles of all resources which have not
    yet been initialized.
  """
  if resource_list is None:
    resource_list = shared_resources() + local_resources()
  else:
    # Materialize once: the value is truth-tested and then iterated twice
    # below, which would silently misbehave for a one-shot iterator (an empty
    # generator is truthy and would reach array_ops.stack([])).
    resource_list = list(resource_list)
  with ops.name_scope(name):
    # Place the check on the device named by the environment variable,
    # falling back to the CPU when it is unset.
    local_device = os.environ.get(
        "TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING", "/cpu:0")
    with ops.device(local_device):
      if not resource_list:
        # Return an empty tensor so we only need to check for returned tensor
        # size being 0 as an indication of model ready.
        return array_ops.constant([], dtype=dtypes.string)
      # 1-D boolean tensor: True for each resource that is NOT initialized.
      uninitialized_mask = math_ops.logical_not(
          array_ops.stack([r.is_initialized for r in resource_list]))
      # 1-D string tensor holding the handle name of every resource.
      resource_names = array_ops.constant(
          [r.handle.name for r in resource_list])
      # Keep only the names whose mask entry is True (the uninitialized ones).
      return array_ops.boolean_mask(resource_names, uninitialized_mask)


@tf_should_use.should_use_result
def initialize_resources(resource_list, name="init"):
  """Initializes the resources in the given list.

  Args:
    resource_list: list of resources to initialize.
    name: name of the initialization op.

  Returns:
    op responsible for initializing all resources. When resource_list is
    empty (or None) this is a no-op with the given name.
  """
  if resource_list:
    return control_flow_ops.group(*[r.create for r in resource_list], name=name)
  return control_flow_ops.no_op(name=name)