- commit
- d3f5e579b6fc9897910ec87023a14366b298faad
- parent
- 919c743e81d34b2eba2473c60d5026d3290d0679
- Author
- Tobias Bengfort <tobias.bengfort@posteo.de>
- Date
- 2026-02-26 22:33
add TaskGroup a more general multiplexing primitive that allows to add tasks after it has already started.
Diffstat
| M | tests.py | 41 | +++++++++++++++++++++++++++++++++++++++++ |
| M | xiio.py | 64 | ++++++++++++++++++++++++++++++++++++++----------------------- |
2 files changed, 81 insertions, 24 deletions
diff --git a/tests.py b/tests.py
@@ -124,6 +124,47 @@ class TestFuture(XiioTestCase): 124 124 await future 125 125 126 126 -1 127 class TestTaskGroup(XiioTestCase): -1 128 async def test_add_tasks_while_running(self): -1 129 async def set_result_later(seconds, future): -1 130 await xiio.sleep(seconds) -1 131 future.set_result(None) -1 132 -1 133 async def foo(tg): -1 134 future = xiio.Future() -1 135 tg.add_task(set_result_later(0.1, future)) -1 136 await future -1 137 -1 138 tg = xiio.TaskGroup() -1 139 tg.add_task(foo(tg)) -1 140 with self.assert_duration(0.1): -1 141 await tg -1 142 -1 143 async def test_starts_tasks_on_await(self): -1 144 stack = [] -1 145 -1 146 async def foo(tg): -1 147 stack.append(1) -1 148 -1 149 tg = xiio.TaskGroup() -1 150 tg.add_task(foo(tg)) -1 151 await xiio.sleep(0.1) -1 152 self.assertEqual(stack, []) -1 153 await tg -1 154 self.assertEqual(stack, [1]) -1 155 -1 156 async def test_removes_finished_tasks(self): -1 157 async def foo(tg): -1 158 task = tg.add_task(xiio.sleep(0.1)) -1 159 self.assertIn(task, tg.tasks) -1 160 await xiio.sleep(0.2) -1 161 self.assertNotIn(task, tg.tasks) -1 162 -1 163 tg = xiio.TaskGroup() -1 164 tg.add_task(foo(tg)) -1 165 await tg -1 166 -1 167 127 168 class TestGather(XiioTestCase): 128 169 async def test_sync_values(self): 129 170 async def return_immediately(value):
diff --git a/xiio.py b/xiio.py
@@ -136,35 +136,51 @@ class Task(typing.Generic[T]): 136 136 self._condition = None 137 137 138 138139 -1 async def gather(coros: list[Coro[T]]) -> list[T]:140 -1 tasks = [Task(coro.__await__()) for coro in coros]141 -1 remaining = tasks[:]142 -1 exc = None143 -1144 -1 while remaining:145 -1 try:146 -1 state = await Condition.combine(147 -1 [task.condition for task in remaining]148 -1 )149 -1 except BaseException as e:150 -1 state = e-1 139 class TaskGroup(typing.Generic[T]): -1 140 def __init__(self) -> None: -1 141 self.tasks: list[Task[T]] = [] -1 142 self.exc: BaseException | None = None -1 143 -1 144 def add_task(self, coro: Coro[T]) -> Task[T]: -1 145 task = Task(coro.__await__()) -1 146 self.tasks.append(task) -1 147 return task 151 148152 -1 for task in remaining[:]:-1 149 def cancel(self, exc: BaseException) -> None: -1 150 if not self.exc: -1 151 self.exc = exc -1 152 for task in self.tasks: -1 153 task.cancel() -1 154 -1 155 def __await__(self) -> Gen[None]: -1 156 while self.tasks: 153 157 try:154 -1 task.resume(state)155 -1 except StopIteration as e:156 -1 remaining.remove(task)157 -1 task.result = e.value-1 158 state = yield Condition.combine( -1 159 [task.condition for task in self.tasks] -1 160 ) 158 161 except BaseException as e:159 -1 remaining.remove(task)160 -1 if not exc:161 -1 exc = e162 -1 for task in remaining:163 -1 task.cancel()-1 162 state = e -1 163 -1 164 for task in self.tasks[:]: -1 165 try: -1 166 task.resume(state) -1 167 except StopIteration as e: -1 168 self.tasks.remove(task) -1 169 task.result = e.value -1 170 except CancelledError: -1 171 self.tasks.remove(task) -1 172 except BaseException as e: -1 173 self.tasks.remove(task) -1 174 self.cancel(e) 164 175165 -1 if exc:166 -1 raise exc-1 176 if self.exc: -1 177 raise self.exc 167 178 -1 179 -1 180 async def gather(coros: list[Coro[T]]) -> list[T]: -1 181 tg = TaskGroup() -1 182 tasks = [tg.add_task(coro) for coro in coros] -1 183 await tg 168 184 return [typing.cast(T, task.result) for task in tasks] 169 185 170 186