# -*- coding: utf-8 -*-
# Copyright 2013 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility classes for the parallelism framework."""

from __future__ import absolute_import

import threading


class AtomicDict(object):
  """Thread-safe (and optionally process-safe) dictionary protected by a lock.

  If a multiprocessing.Manager is supplied on init, the dictionary is
  both process and thread safe. Otherwise, it is only thread-safe.

  Every operation below holds self.lock for its whole duration, so each
  individual call is atomic with respect to other calls that use the same
  lock. Sequences of calls are not atomic as a group; use Increment for
  read-modify-write updates.
  """

  def __init__(self, manager=None):
    """Initializes the dict.

    Args:
      manager: multiprocessing.Manager instance (required for process safety).
          If omitted, a plain threading.Lock and a local dict are used.
    """
    if manager:
      # Manager-created objects are proxies shared across processes.
      self.lock = manager.Lock()
      self.dict = manager.dict()
    else:
      self.lock = threading.Lock()
      self.dict = {}

  def __getitem__(self, key):
    """Returns the value for key; raises KeyError if key is absent."""
    with self.lock:
      return self.dict[key]

  def __setitem__(self, key, value):
    """Stores value under key, replacing any existing value."""
    with self.lock:
      self.dict[key] = value

  def __delitem__(self, key):
    """Removes key (same as delete); raises KeyError if key is absent."""
    self.delete(key)

  def __contains__(self, key):
    """Returns True if key is present.

    Without this method, `key in atomic_dict` would fall back to the legacy
    sequence protocol (calling self[0], self[1], ...). That would raise
    KeyError, or compare against stored *values* when integer keys exist.
    """
    with self.lock:
      return key in self.dict

  def get(self, key, default_value=None):
    """Returns the value for key, or default_value if key is absent."""
    with self.lock:
      return self.dict.get(key, default_value)

  def delete(self, key):
    """Removes key; raises KeyError if key is absent."""
    with self.lock:
      del self.dict[key]

  def Increment(self, key, inc, default_value=0):  # pylint: disable=invalid-name
    """Atomically updates the stored value associated with the given key.

    Performs the atomic equivalent of
    dict[key] = dict.get(key, default_value) + inc.

    Args:
      key: lookup key for the value of the first operand of the "+" operation.
      inc: Second operand of the "+" operation.
      default_value: Default value if there is no existing value for the key.

    Returns:
      Incremented value.
    """
    # The read and the write happen under one lock acquisition, so concurrent
    # increments cannot lose updates.
    with self.lock:
      val = self.dict.get(key, default_value) + inc
      self.dict[key] = val
      return val