Source code for kiel.zookeeper.shared_set
import json
import logging
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
class SharedSet(object):
    """
    A simple "set" construct in Zookeeper with locking and change callbacks.

    Used by the Zookeeper-based ``GroupedConsumer`` to represent the shared set
    of topic partitions divvied up among the group.

    The set is stored as a JSON list in the data of the znode at ``path``.
    Mutations are serialized through a lock whose path is ``<path>/lock``.
    """

    def __init__(self, client, path, on_change):
        """
        :param client: Zookeeper client exposing ``ensure_path``, ``get``,
            ``set``, ``Lock`` and ``DataWatch`` (presumably a kazoo
            ``KazooClient`` -- confirm against callers).
        :param path: znode path whose data holds the serialized set.
        :param on_change: callable invoked by the data watch registered in
            ``start()``. It receives the deserialized set, or ``None`` when the
            watch reports no data.
        """
        self.client = client
        self.path = path
        self.on_change = on_change

    @property
    def lock_path(self):
        """
        Property representing the znode path of the shared lock.
        """
        return self.path + "/lock"

    def start(self):
        """
        Ensures the set's znode path exists and attaches the data-change
        callback.
        """
        self.client.ensure_path(self.path)

        @self.client.DataWatch(self.path)
        def set_changed(data, stat):
            # ``None`` is forwarded untouched so ``on_change`` can tell
            # "no data" apart from "empty set".
            if data is not None:
                data = self.deserialize(data)
            self.on_change(data)

    def add_items(self, new_items):
        """
        Updates the shared set's data with the given new items added.

        If all of the given items are already present, no data is updated.
        Works entirely behind a zookeeper lock to combat resource contention
        among sharers of the set.

        :param new_items: any iterable of hashable, JSON-serializable items.
        """
        new_items = set(new_items)
        with self.client.Lock(self.lock_path):
            # deserialize() always returns a set (empty for blank data).
            existing_items = self.deserialize(self.client.get(self.path)[0])
            if new_items.issubset(existing_items):
                return
            existing_items.update(new_items)
            self._write(existing_items)

    def remove_items(self, old_items):
        """
        Updates the shared set's data with the given items removed.

        If none of the given items are present, no data is updated.
        Works entirely behind a zookeeper lock to combat resource contention
        among sharers of the set.

        :param old_items: any iterable of hashable items.
        """
        old_items = set(old_items)
        with self.client.Lock(self.lock_path):
            existing_items = self.deserialize(self.client.get(self.path)[0])
            if old_items.isdisjoint(existing_items):
                return
            existing_items.difference_update(old_items)
            self._write(existing_items)

    def _write(self, items):
        """
        Stores ``items`` as the znode's data. Callers must hold the lock.
        """
        # The Zookeeper client requires a byte string for znode data. The JSON
        # from serialize() is ASCII-only (ensure_ascii), so encoding is lossless.
        self.client.set(self.path, self.serialize(items).encode("utf-8"))

    def serialize(self, data):
        """
        Serializes the set data as a list in a JSON string.
        """
        return json.dumps(list(data))

    def deserialize(self, data):
        """
        Parses a given JSON list (``str`` or UTF-8 ``bytes``) into a python set.

        ``None`` or empty data yields an empty set.
        """
        if not data:
            return set()
        if isinstance(data, bytes):
            # Znode data arrives as bytes; json.loads only accepts bytes on
            # Python 3.6+, so decode explicitly.
            data = data.decode("utf-8")
        return set(json.loads(data))