Skip to content

Commit 4bcb707

Browse files
authored
Merge pull request #55 from labthings/gevent-pool
Updated task management to better match Gevent pool interface
2 parents d279815 + 9c723c3 commit 4bcb707

File tree

10 files changed

+103
-122
lines changed

10 files changed

+103
-122
lines changed

labthings/core/tasks/__init__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,26 @@
11
__all__ = [
2+
"Pool",
23
"taskify",
34
"tasks",
4-
"dictionary",
5+
"to_dict",
56
"states",
67
"current_task",
78
"update_task_progress",
8-
"cleanup_tasks",
9-
"remove_task",
9+
"cleanup",
10+
"discard_id",
1011
"update_task_data",
1112
"ThreadTerminationError",
1213
]
1314

1415
from .pool import (
16+
Pool,
1517
tasks,
16-
dictionary,
18+
to_dict,
1719
states,
1820
current_task,
1921
update_task_progress,
20-
cleanup_tasks,
21-
remove_task,
22+
cleanup,
23+
discard_id,
2224
update_task_data,
2325
taskify,
2426
)

labthings/core/tasks/pool.py

Lines changed: 49 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,71 @@
11
import logging
22
from functools import wraps
33
from gevent import getcurrent
4+
from gevent.pool import Pool as _Pool, PoolFull
45

56
from .thread import TaskThread
67

7-
from flask import copy_current_request_context, has_request_context
88

9+
class Pool(_Pool):
10+
def __init__(self, size=None):
11+
_Pool.__init__(self, size=size, greenlet_class=TaskThread)
912

10-
class TaskMaster:
11-
def __init__(self, *args, **kwargs):
12-
self._tasks = []
13+
def add(self, greenlet, blocking=True, timeout=None):
14+
"""
15+
Override the default Gevent pool `add` method so that
16+
tasks are not discarded as soon as they finish.
17+
"""
18+
if not self._semaphore.acquire(blocking=blocking, timeout=timeout):
19+
# We failed to acquire the semaphore.
20+
# If blocking was True, then there was a timeout. If blocking was
21+
# False, then there was no capacity. Either way, raise PoolFull.
22+
raise PoolFull()
23+
24+
try:
25+
self.greenlets.add(greenlet)
26+
self._empty_event.clear()
27+
except:
28+
self._semaphore.release()
29+
raise
1330

14-
@property
1531
def tasks(self):
1632
"""
1733
Returns:
1834
list: List of TaskThread objects.
1935
"""
20-
return self._tasks
36+
return list(self.greenlets)
2137

22-
@property
23-
def dict(self):
38+
def states(self):
2439
"""
2540
Returns:
26-
dict: Dictionary of TaskThread objects. Key is TaskThread ID.
41+
dict: Dictionary of TaskThread.state dictionaries. Key is TaskThread ID.
2742
"""
28-
return {str(t.id): t for t in self._tasks}
43+
return {str(t.id): t.state for t in self.greenlets}
2944

30-
@property
31-
def states(self):
45+
def to_dict(self):
3246
"""
3347
Returns:
34-
dict: Dictionary of TaskThread.state dictionaries. Key is TaskThread ID.
48+
dict: Dictionary of TaskThread objects. Key is TaskThread ID.
3549
"""
36-
return {str(t.id): t.state for t in self._tasks}
37-
38-
def new(self, f, *args, **kwargs):
39-
# copy_current_request_context allows threads to access flask current_app
40-
if has_request_context():
41-
target = copy_current_request_context(f)
42-
else:
43-
target = f
44-
task = TaskThread(target=target, args=args, kwargs=kwargs)
45-
self._tasks.append(task)
46-
return task
50+
return {str(t.id): t for t in self.greenlets}
4751

48-
def remove(self, task_id):
49-
for task in self._tasks:
52+
def discard_id(self, task_id):
53+
marked_for_discard = set()
54+
for task in self.greenlets:
5055
if (str(task.id) == str(task_id)) and task.dead:
51-
self._tasks.remove(task)
56+
marked_for_discard.add(task)
57+
58+
for greenlet in marked_for_discard:
59+
self.discard(greenlet)
5260

5361
def cleanup(self):
54-
for i, task in enumerate(self._tasks):
62+
marked_for_discard = set()
63+
for task in self.greenlets:
5564
if task.dead:
56-
# Mark for delection
57-
self._tasks[i] = None
58-
# Remove items marked for deletion
59-
self._tasks = [t for t in self._tasks if t]
60-
61-
62-
# Task management
63-
64-
65-
def tasks():
66-
"""
67-
List of tasks in default taskmaster
68-
Returns:
69-
list: List of tasks in default taskmaster
70-
"""
71-
global DEFAULT_TASK_MASTER
72-
return DEFAULT_TASK_MASTER.tasks
65+
marked_for_discard.add(task)
7366

74-
75-
def dictionary():
76-
"""
77-
Dictionary of tasks in default taskmaster
78-
Returns:
79-
dict: Dictionary of tasks in default taskmaster
80-
"""
81-
global DEFAULT_TASK_MASTER
82-
return DEFAULT_TASK_MASTER.dict
83-
84-
85-
def states():
86-
"""
87-
Dictionary of TaskThread.state dictionaries. Key is TaskThread ID.
88-
Returns:
89-
dict: Dictionary of task states in default taskmaster
90-
"""
91-
global DEFAULT_TASK_MASTER
92-
return DEFAULT_TASK_MASTER.states
93-
94-
95-
def cleanup_tasks():
96-
"""Remove all finished tasks from the task list"""
97-
global DEFAULT_TASK_MASTER
98-
return DEFAULT_TASK_MASTER.cleanup()
99-
100-
101-
def remove_task(task_id: str):
102-
"""Remove a particular task from the task list
103-
104-
Arguments:
105-
task_id {str} -- ID of the target task
106-
"""
107-
global DEFAULT_TASK_MASTER
108-
return DEFAULT_TASK_MASTER.remove(task_id)
67+
for greenlet in marked_for_discard:
68+
self.discard(greenlet)
10969

11070

11171
# Operations on the current task
@@ -161,17 +121,23 @@ def taskify(f):
161121
A decorator that wraps the passed in function
162122
and surpresses exceptions should one occur
163123
"""
124+
global default_pool
164125

165126
@wraps(f)
166127
def wrapped(*args, **kwargs):
167-
task = DEFAULT_TASK_MASTER.new(
128+
task = default_pool.spawn(
168129
f, *args, **kwargs
169130
) # Append to parent object's task list
170-
task.start() # Start the function
171131
return task
172132

173133
return wrapped
174134

175135

176136
# Create our default, protected, module-level task pool
177-
DEFAULT_TASK_MASTER = TaskMaster()
137+
default_pool = Pool()
138+
139+
tasks = default_pool.tasks
140+
to_dict = default_pool.to_dict
141+
states = default_pool.states
142+
cleanup = default_pool.cleanup
143+
discard_id = default_pool.discard_id

labthings/core/tasks/thread.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from gevent import Greenlet, GreenletExit
22
from gevent.thread import get_ident
33
from gevent.event import Event
4+
from flask import copy_current_request_context, has_request_context
45
import datetime
56
import logging
67
import traceback
@@ -18,13 +19,8 @@ class TaskKillException(Exception):
1819

1920

2021
class TaskThread(Greenlet):
21-
def __init__(self, target=None, args=None, kwargs=None):
22+
def __init__(self, target, *args, **kwargs):
2223
Greenlet.__init__(self)
23-
# Handle arguments
24-
if args is None:
25-
args = ()
26-
if kwargs is None:
27-
kwargs = {}
2824

2925
# A UUID for the TaskThread (not the same as the threading.Thread ident)
3026
self._ID = uuid.uuid4() # Task ID
@@ -83,7 +79,12 @@ def update_data(self, data: dict):
8379
self.data.update(data)
8480

8581
def _run(self): # pylint: disable=E0202
86-
return self._thread_proc(self._target)(*self._args, **self._kwargs)
82+
# copy_current_request_context allows threads to access flask current_app
83+
if has_request_context():
84+
target = copy_current_request_context(self._target)
85+
else:
86+
target = self._target
87+
return self._thread_proc(target)(*self._args, **self._kwargs)
8788

8889
def _thread_proc(self, f):
8990
"""

labthings/server/default_views/tasks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ def get(self, task_id):
3131
3232
Includes progress and intermediate data.
3333
"""
34-
task_dict = tasks.dictionary()
34+
task_dict = tasks.to_dict()
3535

36-
if not task_id in task_dict:
36+
if task_id not in task_dict:
3737
return abort(404) # 404 Not Found
3838

3939
task = task_dict.get(task_id)
@@ -47,9 +47,9 @@ def delete(self, task_id):
4747
4848
If the task is finished, deletes its entry.
4949
"""
50-
task_dict = tasks.dictionary()
50+
task_dict = tasks.to_dict()
5151

52-
if not task_id in task_dict:
52+
if task_id not in task_dict:
5353
return abort(404) # 404 Not Found
5454

5555
task = task_dict.get(task_id)

tests/test_core_tasks_pool.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,10 @@ def test_tasks_list():
5555

5656
def test_tasks_dict():
5757
assert all(
58-
[
59-
isinstance(task_obj, gevent.Greenlet)
60-
for task_obj in tasks.dictionary().values()
61-
]
58+
[isinstance(task_obj, gevent.Greenlet) for task_obj in tasks.to_dict().values()]
6259
)
6360

64-
assert all([k == str(t.id) for k, t in tasks.dictionary().items()])
61+
assert all([k == str(t.id) for k, t in tasks.to_dict().items()])
6562

6663

6764
def test_task_states():
@@ -80,16 +77,16 @@ def test_task_states():
8077
assert all(k in state for k in state_keys)
8178

8279

83-
def test_remove_task():
80+
def test_discard_id():
8481
def task_func():
8582
pass
8683

8784
task_obj = tasks.taskify(task_func)()
88-
assert str(task_obj.id) in tasks.dictionary()
85+
assert str(task_obj.id) in tasks.to_dict()
8986
task_obj.join()
9087

91-
tasks.remove_task(task_obj.id)
92-
assert not str(task_obj.id) in tasks.dictionary()
88+
tasks.discard_id(task_obj.id)
89+
assert not str(task_obj.id) in tasks.to_dict()
9390

9491

9592
def test_cleanup_task():
@@ -105,5 +102,5 @@ def task_func():
105102
gevent.joinall(tasks.tasks())
106103

107104
assert len(tasks.tasks()) > 0
108-
tasks.cleanup_tasks()
105+
tasks.cleanup()
109106
assert len(tasks.tasks()) == 0

tests/test_core_tasks_thread.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@ def test_task_with_args():
1010
def task_func(arg, kwarg=False):
1111
pass
1212

13-
task_obj = thread.TaskThread(
14-
target=task_func, args=("String arg",), kwargs={"kwarg": True}
15-
)
13+
task_obj = thread.TaskThread(task_func, "String arg", kwarg=True)
1614
assert isinstance(task_obj, gevent.Greenlet)
1715
assert task_obj._target == task_func
1816
assert task_obj._args == ("String arg",)
@@ -121,7 +119,7 @@ def test_task_log_without_thread():
121119

122120
def test_task_log_with_incorrect_thread():
123121

124-
task_obj = thread.TaskThread()
122+
task_obj = thread.TaskThread(None)
125123
task_log_handler = thread.ThreadLogHandler(thread=task_obj)
126124

127125
# Should always return False if called from outside the log handlers thread

tests/test_server_decorators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def func():
9191

9292
def test_marshal_task(app_ctx):
9393
def func():
94-
return TaskThread()
94+
return TaskThread(None)
9595

9696
wrapped_func = decorators.marshal_task(func)
9797

@@ -102,7 +102,7 @@ def func():
102102

103103
def test_marshal_task_response_tuple(app_ctx):
104104
def func():
105-
return (TaskThread(), 201, {})
105+
return (TaskThread(None), 201, {})
106106

107107
wrapped_func = decorators.marshal_task(func)
108108

tests/test_server_default_views.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from labthings.core.tasks import taskify, dictionary
1+
from labthings.core import tasks
22

33
import gevent
44

@@ -22,7 +22,7 @@ def test_tasks_list(thing_client):
2222
def task_func():
2323
pass
2424

25-
task_obj = taskify(task_func)()
25+
task_obj = tasks.taskify(task_func)()
2626

2727
with thing_client as c:
2828
response = c.get("/tasks").json
@@ -34,7 +34,7 @@ def test_task_representation(thing_client):
3434
def task_func():
3535
pass
3636

37-
task_obj = taskify(task_func)()
37+
task_obj = tasks.taskify(task_func)()
3838
task_id = str(task_obj.id)
3939

4040
with thing_client as c:
@@ -52,12 +52,12 @@ def task_func():
5252
while True:
5353
gevent.sleep(0)
5454

55-
task_obj = taskify(task_func)()
55+
task_obj = tasks.taskify(task_func)()
5656
task_id = str(task_obj.id)
5757

5858
# Wait for task to start
5959
task_obj.started_event.wait()
60-
assert task_id in dictionary()
60+
assert task_id in tasks.to_dict()
6161

6262
# Send a DELETE request to terminate the task
6363
with thing_client as c:

0 commit comments

Comments
 (0)