- commit
- 0df8fb90d90a5f881d9815ac411c82c769ee5a00
- parent
- d3f5e579b6fc9897910ec87023a14366b298faad
- Author
- Tobias Bengfort <tobias.bengfort@posteo.de>
- Date
- 2026-02-27 15:55
turn TaskGroup into a context manager
This is wild and I am honestly not 100% sold that this is worth it.
Feature-wise, the main difference from the previous version is that the
subtasks already start executing before the inner block is done.
A similar thing was already possible before by using an additional
function:
```python
async def foo(tg):
tg.add_task(sleep(1))
sleep(1)
tg = TaskGroup()
tg.add_task(foo(tg))
await tg
```
But I must admit that the context manager syntax is nicer:
```python
async with TaskGroup() as tg:
tg.add_task(sleep(1))
sleep(1)
```
Unfortunately, the implementation is a real mind bender:
When we enter the TaskGroup, we want to wrap the parent task. In order
to do that, we first need to get a reference to it. The way we do that
is a bit hacky: We usually use the async/await syntax to pass conditions
and available files back and forth between the coroutines and our
plumbing. Now we abuse the same system to pass a task into the
coroutine.
Now we put the generator of the coroutine that is currently running in a
new task and add it to the TaskGroup. The parent task instead receives
the generator of the `wrapper()` coroutine.
The next time we await, a condition will be yielded to the parent task,
which now references a completely different generator. In addition, the
task that now wraps this coroutine will never receive the condition.
Fortunately, there is a simple fix: We can await
`Condition(time=-math.inf)` in both coroutines to make everything line
up.
After that switcharoo is done, we can `await self` to wait for the
TaskGroup to complete. As part of that, we will at some point exit the
TaskGroup. At that point we abuse the async/await syntax a second time
to raise an exception, which will finish the task without stopping the
generator.
Once the TaskGroup has finished executing, we can attach the generator
back to the parent task and yield control back to it. It will resume
execution of its original generator inside `__aexit__()`. `wrapper()` is
not referenced anywhere anymore, so it will be garbage collected.
Diffstat
| M | README.md | 31 | +++++++++++++++++++++++++++++++ |
| M | tests.py | 37 | +++++++++++++++++-------------------- |
| M | xiio.py | 44 | +++++++++++++++++++++++++++++++++++++++++--- |
3 files changed, 89 insertions, 23 deletions
diff --git a/README.md b/README.md
@@ -31,6 +31,37 @@ async def main(): 31 31 xiio.run(main()) 32 32 ``` 33 33 -1 34 ## Structured Concurrency -1 35 -1 36 Similar to [nurseries in -1 37 trio](https://vorpus.org/blog/notes-on-structured-concurrency-or-go-statement-considered-harmful/) -1 38 and [task groups in -1 39 asyncio](https://docs.python.org/3/library/asyncio-task.html#asyncio.TaskGroup), -1 40 xiio provides a low level primitive that controls the lifetime of subtasks. -1 41 For example, `gather()` is just a higher level abstraction on top of that: -1 42 -1 43 ```python -1 44 async def gather(coros): -1 45 async with TaskGroup() as tg: -1 46 tasks = [tg.add_task(coro) for coro in coros] -1 47 return [task.result for task in tasks] -1 48 ``` -1 49 -1 50 Task groups in xiio have the following properties: -1 51 -1 52 - All subtasks are guaranteed to have finished when the task group exits. -1 53 - Subtasks are not started immediately. They have a chance to get started the -1 54 next time the main task awaits. -1 55 - If any task in a task group raises an exception, a `xiio.CancelledError` is -1 56 raised in all other tasks. The tasks are then responsible for cleaning up -1 57 quickly. They may still await async functions if necessary. -1 58 - Any exceptions that are raised after cancellation are lost. Only the first -1 59 one is raised after cleanup is done. -1 60 - Tasks are removed from the task group once they are done. If you need their -1 61 results, keep the reference that is returned by `TaskGroup.add_task()`. -1 62 - It is possible to add new tasks while the task group is already running, -1 63 and even after cancellation. -1 64 34 65 ## Design 35 66 36 67 I spent quite some time creating meaningful commits. So if you want to
diff --git a/tests.py b/tests.py
@@ -130,40 +130,37 @@ class TestTaskGroup(XiioTestCase): 130 130 await xiio.sleep(seconds) 131 131 future.set_result(None) 132 132133 -1 async def foo(tg):134 -1 future = xiio.Future()135 -1 tg.add_task(set_result_later(0.1, future))136 -1 await future137 -1138 -1 tg = xiio.TaskGroup()139 -1 tg.add_task(foo(tg))140 133 with self.assert_duration(0.1):141 -1 await tg-1 134 async with xiio.TaskGroup() as tg: -1 135 future = xiio.Future() -1 136 tg.add_task(set_result_later(0.1, future)) -1 137 await future -1 138 -1 139 async def test_exception_in_inner_block(self): -1 140 with self.assert_duration(0): -1 141 with self.assertRaises(ValueError): -1 142 async with xiio.TaskGroup() as tg: -1 143 tg.add_task(xiio.sleep(0.3)) -1 144 raise ValueError 142 145143 -1 async def test_starts_tasks_on_await(self):-1 146 async def test_starts_tasks_on_next_pause(self): 144 147 stack = [] 145 148 146 149 async def foo(tg): 147 150 stack.append(1) 148 151149 -1 tg = xiio.TaskGroup()150 -1 tg.add_task(foo(tg))151 -1 await xiio.sleep(0.1)152 -1 self.assertEqual(stack, [])153 -1 await tg154 -1 self.assertEqual(stack, [1])-1 152 async with xiio.TaskGroup() as tg: -1 153 tg.add_task(foo(tg)) -1 154 await xiio.sleep(0.1) -1 155 self.assertEqual(stack, [1]) 155 156 156 157 async def test_removes_finished_tasks(self):157 -1 async def foo(tg):-1 158 async with xiio.TaskGroup() as tg: 158 159 task = tg.add_task(xiio.sleep(0.1)) 159 160 self.assertIn(task, tg.tasks) 160 161 await xiio.sleep(0.2) 161 162 self.assertNotIn(task, tg.tasks) 162 163163 -1 tg = xiio.TaskGroup()164 -1 tg.add_task(foo(tg))165 -1 await tg166 -1167 164 168 165 class TestGather(XiioTestCase): 169 166 async def test_sync_values(self):
diff --git a/xiio.py b/xiio.py
@@ -65,6 +65,17 @@ class Condition: 65 65 return {key.fd: events for key, events in selected} 66 66 67 67 -1 68 class ThrowCondition(Condition): -1 69 def __init__(self, exc: BaseException) -> None: -1 70 super().__init__() -1 71 self.exc = exc -1 72 -1 73 -1 74 class GetTaskCondition(Condition): -1 75 def __init__(self): -1 76 super().__init__(time=-math.inf) -1 77 -1 78 68 79 async def sleep(seconds: float) -> None: 69 80 await Condition(time=time.monotonic() + seconds) 70 81 @@ -131,6 +142,14 @@ class Task(typing.Generic[T]): 131 142 elif self.condition.fulfilled(state): 132 143 self._condition = self.gen.send(state) 133 144 -1 145 while isinstance(self._condition, GetTaskCondition): -1 146 self._condition = self.gen.send(typing.cast(Files, self)) -1 147 -1 148 if isinstance(self._condition, ThrowCondition): -1 149 exc = self._condition.exc -1 150 self._condition = None -1 151 raise exc -1 152 134 153 def cancel(self) -> None: 135 154 self._cancel_soon = True 136 155 self._condition = None @@ -173,14 +192,33 @@ class TaskGroup(typing.Generic[T]): 173 192 self.tasks.remove(task) 174 193 self.cancel(e) 175 194 -1 195 async def __aenter__(self) -> 'TaskGroup[T]': -1 196 parent_task = typing.cast(Task[T], await GetTaskCondition()) -1 197 gen = parent_task.gen -1 198 -1 199 async def wrapper(): -1 200 await Condition(time=-math.inf) -1 201 await self -1 202 parent_task.gen = gen -1 203 parent_task._condition = None -1 204 await Condition(time=-math.inf) -1 205 -1 206 self.tasks.append(Task(gen)) -1 207 parent_task.gen = typing.cast(typing.Any, wrapper().__await__()) -1 208 next(parent_task.gen) -1 209 await Condition(time=-math.inf) -1 210 -1 211 return self -1 212 -1 213 async def __aexit__(self, exc_type, exc: BaseException | None, traceback) -> None: -1 214 await ThrowCondition(exc or StopIteration()) 176 215 if self.exc: 177 216 raise self.exc 178 217 179 218 180 219 async def gather(coros: list[Coro[T]]) -> list[T]:181 -1 tg = TaskGroup()182 -1 tasks = [tg.add_task(coro) for coro in coros]183 -1 await tg-1 220 async with TaskGroup() as tg: -1 221 tasks = [tg.add_task(coro) for coro in coros] 184 222 return [typing.cast(T, task.result) for task in tasks] 185 223 186 224