Coverage for typed_stream/_impl/_iteration_utils.py: 98%
176 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-23 18:47 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-23 18:47 +0000
1# Licensed under the EUPL-1.2 or later.
2# You may obtain a copy of the licence in all the official languages of the
3# European Union at https://joinup.ec.europa.eu/collection/eupl/eupl-text-eupl-12
5"""Utility classes used in streams."""
7from __future__ import annotations
9import collections
10import contextlib
11import itertools
12from collections.abc import Callable, Iterable, Iterator
13from typing import Generic, Literal, TypeVar, Union, cast, overload
15from ..streamable import Streamable
16from ._types import ClassWithCleanUp, IteratorProxy, PrettyRepr
17from ._typing import Self, override
18from ._utils import (
19 FunctionWrapperIgnoringArgs,
20 IndexValueTuple,
21 count_required_positional_arguments,
22 wrap_in_tuple,
23)
24from .functions import one
26__all__ = (
27 "Chunked",
28 "Enumerator",
29 "ExceptionHandler",
30 "IfElseMap",
31 "IterWithCleanUp",
32 "Peeker",
33 "count",
34 "sliding_window",
35)
37T = TypeVar("T")
38U = TypeVar("U")
39V = TypeVar("V")
41Exc = TypeVar("Exc", bound=BaseException)
def count(it: Iterable[object]) -> int:
    """Return how many items the iterable yields.

    The iterable is iterated to its end in order to count it.
    """
    return sum(1 for _ in it)
class Chunked(
    IteratorProxy[tuple[T, ...], T],
    Streamable[tuple[T, ...]],
    Generic[T],
):
    """Split an iterable into tuples of chunk_size consecutive items.

    Only the final chunk may contain fewer than chunk_size items.

    Inspired by batched from:
    https://docs.python.org/3/library/itertools.html?highlight=callable#itertools-recipes

    >>> chunks = Chunked("abcd", 2)
    >>> assert "Chunked" in repr(chunks)
    >>> assert "2" in repr(chunks)
    >>> list(chunks)
    [('a', 'b'), ('c', 'd')]
    """

    chunk_size: int

    __slots__ = ("chunk_size",)

    def __init__(self, iterable: Iterable[T], chunk_size: int) -> None:
        """Prepare chunking of iterable into tuples of chunk_size items.

        Raises ValueError if chunk_size is smaller than one.
        """
        if chunk_size < 1:
            raise ValueError("size must be at least one")
        super().__init__(iterable)
        self.chunk_size = chunk_size

    @override
    def __next__(self) -> tuple[T, ...]:
        """Return the next chunk, or stop once no items are left."""
        chunk = tuple(itertools.islice(self._iterator, self.chunk_size))
        if not chunk:
            # An empty slice means the underlying iterator is exhausted.
            raise StopIteration()
        return chunk

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the arguments this instance was initialized with."""
        return (*super()._get_args(), self.chunk_size)
class Enumerator(IteratorProxy[IndexValueTuple[T], T], Generic[T]):
    """Pair every item with a running index, yielding IndexValueTuples.

    Works like enumerate() apart from the tuple type that is yielded.
    """

    _curr_idx: int

    __slots__ = ("_curr_idx",)

    def __init__(self, iterable: Iterable[T], start_index: int) -> None:
        """Enumerate iterable, giving its first item the index start_index."""
        super().__init__(iterable)
        self._curr_idx = start_index

    @override
    def __next__(self: Enumerator[T]) -> IndexValueTuple[T]:
        """Return the next (index, value) pair as an IndexValueTuple."""
        # Fetch the value first: when the source is exhausted, StopIteration
        # propagates from here and the index is left untouched.
        value: T = next(self._iterator)
        pair: tuple[int, T] = (self._curr_idx, value)
        self._curr_idx += 1
        return IndexValueTuple(pair)

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the wrapped iterator args followed by the current index."""
        return (*super()._get_args(), self._curr_idx)
# pylint: disable-next=consider-alternative-union-syntax
class ExceptionHandler(IteratorProxy[Union[T, U], T], Generic[T, U, Exc]):
    """Iterator proxy that handles selected exceptions raised by its source.

    When fetching the next item raises one of the configured exceptions, the
    optional log callable is invoked with the exception. Afterwards either
    the value produced by the default callable is yielded in place of the
    failing item or, if no default callable was given, the failing item is
    skipped and iteration continues with the next one.

    StopIteration is never handled; it always ends the iteration.
    """

    _exception_class: type[Exc] | tuple[type[Exc], ...]
    _default_fun: Callable[[Exc], U] | None
    _log_fun: Callable[[Exc], object] | None

    __slots__ = ("_exception_class", "_default_fun", "_log_fun")

    def __init__(
        self,
        iterable: Iterable[T],
        exception_class: type[Exc] | tuple[type[Exc], ...],
        log_callable: Callable[[Exc], object] | None = None,
        default_factory: Callable[[Exc], U] | Callable[[], U] | None = None,
    ) -> None:
        """Handle Exceptions in iterables.

        Args:
            iterable: the source of the values.
            exception_class: exception type (or tuple of types) to handle.
            log_callable: optional callable invoked with each handled
                exception.
            default_factory: optional callable producing a replacement value
                for a failing item. It may take the exception as its only
                argument or take no required positional argument at all.

        Raises:
            ValueError: if exception_class is, or contains, StopIteration.
        """
        super().__init__(iterable)
        if (
            (StopIteration in exception_class)
            if isinstance(exception_class, tuple)
            else (exception_class is StopIteration)
        ):
            raise ValueError("Cannot catch StopIteration")
        self._exception_class = exception_class
        self._log_fun = log_callable
        if default_factory is None:
            self._default_fun = None
        elif count_required_positional_arguments(default_factory):
            self._default_fun = cast(Callable[[Exc], U], default_factory)
        else:
            # The factory does not require an argument, so wrap it to make
            # it callable with the exception like every other default fun.
            self._default_fun = FunctionWrapperIgnoringArgs(
                cast(Callable[[], U], default_factory)
            )

    @override
    def __next__(self: ExceptionHandler[T, U, Exc]) -> T | U:  # noqa: C901
        """Return the next value or the default for a failing item."""
        while True:  # pylint: disable=while-used
            try:
                value: T = next(self._iterator)
            except StopIteration:
                # Must come first: exception classes such as Exception would
                # otherwise swallow the end-of-iteration signal.
                raise
            except self._exception_class as exc:
                # Compare against None instead of relying on truthiness so
                # that callables which evaluate as false are still used.
                if self._log_fun is not None:
                    self._log_fun(exc)
                if self._default_fun is not None:
                    return self._default_fun(exc)
                # No default available: loop to try the next element.
            else:
                return value

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the wrapped iterator args plus the handler configuration."""
        return (
            *super()._get_args(),
            self._exception_class,
            self._log_fun,
            self._default_fun,
        )
# pylint: disable-next=consider-alternative-union-syntax
class IfElseMap(IteratorProxy[Union[U, V], T], Generic[T, U, V]):
    """Map values with one of two functions, chosen by a condition.

    Items for which the condition holds are mapped with if_. All other items
    are mapped with else_ or, when no else_ was given, dropped.
    """

    _condition: Callable[[T], bool | object]
    _if_fun: Callable[[T], U]
    _else_fun: Callable[[T], V] | None

    __slots__ = ("_condition", "_if_fun", "_else_fun")

    def __init__(
        self,
        iterable: Iterable[T],
        condition: Callable[[T], bool | object],
        if_: Callable[[T], U],
        else_: Callable[[T], V] | None = None,
    ) -> None:
        """Map values depending on a condition.

        Equivalent pairs:
        - map(lambda _: (if_(_) if condition(_) else else_(_)), iterable)
        - IfElseMap(iterable, condition, if_, else_)

        - filter(callable, iterable)
        - IfElseMap(iterable, callable, lambda _: _, None)

        Args:
            iterable: the source of the values.
            condition: decides (by truthiness of its result) which function
                is applied to a value.
            if_: applied to values for which condition holds.
            else_: applied to all other values; None drops them instead.

        Raises:
            ValueError: if both if_ and else_ are None.
        """
        super().__init__(iterable)
        self._condition = condition
        # NOTE(review): if_ is annotated as required, yet None is only
        # rejected together with a missing else_. Passing if_=None with an
        # else_ passes this check, but __next__ calls if_ unconditionally
        # for matching values.
        if if_ is else_ is None:
            raise ValueError("if_ and else_ must not both be None")
        self._if_fun = if_
        self._else_fun = else_

    @override
    def __next__(self: IfElseMap[T, U, V]) -> U | V:
        """Return the next mapped value."""
        while True:  # pylint: disable=while-used
            value: T = next(self._iterator)
            if self._condition(value):
                return self._if_fun(value)
            # Compare against None instead of relying on truthiness so that
            # a callable which evaluates as false is still applied.
            if self._else_fun is not None:
                return self._else_fun(value)
            # No else function: drop this value and try the next element.

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the wrapped iterator args plus condition and functions."""
        return (
            *super()._get_args(),
            self._condition,
            self._if_fun,
            self._else_fun,
        )
class Peeker(Generic[T], PrettyRepr):
    """Callable that shows each value to a function and hands it back."""

    fun: Callable[[T], object | None]

    __slots__ = ("fun",)

    def __init__(self, fun: Callable[[T], object | None]) -> None:
        """Store the function every value gets passed to."""
        self.fun = fun

    def __call__(self, value: T, /) -> T:
        """Pass value to fun, ignore the result and return value unchanged."""
        observer = self.fun
        observer(value)
        return value

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the arguments this instance was initialized with."""
        args: tuple[object, ...] = (self.fun,)
        return args
class IterWithCleanUp(Iterator[T], ClassWithCleanUp):
    """An Iterator that runs a clean-up function once it is finished.

    The clean-up function is called once, in one of these situations:
    - the iteration has been completed
    - .close() gets called
    - .__del__() gets called
    - it's used in a context manager and .__exit__() gets called

    Things to avoid (as running the clean-up function is probably important):
    - calling next(this) just once
    - breaking out of a for loop over this without closing this
    - iterating over this only partially without closing it
    """

    iterator: Iterator[T] | None

    __slots__ = ("iterator",)

    def __init__(
        self, iterable: Iterable[T], cleanup_fun: Callable[[], object | None]
    ) -> None:
        """Wrap iterable and register cleanup_fun as the clean-up function."""
        super().__init__(cleanup_fun)
        self.iterator = iter(iterable)

    @override
    def __iter__(self) -> Self:
        """Return self, as this object is its own iterator."""
        return self

    @override
    def __next__(self) -> T:
        """Return the next element; close self when none can be returned."""
        source = self.iterator
        if source is None:
            # Already closed: nothing more to yield.
            self.close()
            raise StopIteration
        try:
            return next(source)
        except BaseException:
            # Exhaustion (StopIteration) and failures alike trigger the
            # clean-up. Exceptions raised by close() are suppressed so that
            # the exception from the iterator is the one that propagates.
            with contextlib.suppress(Exception):
                self.close()
            raise

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the clean-up args followed by the wrapped iterator."""
        return (*super()._get_args(), self.iterator)

    @override
    def close(self) -> None:
        """Run clean-up if not run yet and drop the wrapped iterator."""
        super().close()
        self.iterator = None
class SlidingWindow(IteratorProxy[tuple[T, ...], T], Generic[T]):
    """Yield overlapping tuples of size consecutive items from an iterable.

    Nothing is yielded if the iterable provides fewer than size items.

    Inspired by sliding_window from:
    https://docs.python.org/3/library/itertools.html#itertools-recipes
    """

    _window: collections.deque[T]

    __slots__ = ("_window",)

    def __init__(self, iterable: Iterable[T], size: int) -> None:
        """Slide a window of size items over iterable.

        Raises ValueError if size is smaller than one.
        """
        if size < 1:
            raise ValueError("size needs to be a positive integer")
        super().__init__(iterable)
        # The deque's maxlen stores the window size and makes append() drop
        # the oldest item automatically.
        self._window = collections.deque((), maxlen=size)

    @override
    def __next__(self: SlidingWindow[T]) -> tuple[T, ...]:
        """Return the next window as a tuple of size items."""
        size = self.size
        window = self._window
        missing = size - len(window)
        if missing:
            # The window is not full yet: try to fill it up in one go.
            window.extend(itertools.islice(self._iterator, missing))
            if len(window) < size:
                window.clear()
                raise StopIteration()
            return tuple(window)
        try:
            newest = next(self._iterator)
        except StopIteration:
            window.clear()
            raise
        window.append(newest)
        return tuple(window)

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the wrapped iterator args followed by the window size."""
        return (*super()._get_args(), self.size)

    @property
    def size(self) -> int:
        """Return the size of the sliding window."""
        maxlen = self._window.maxlen
        return cast(int, maxlen)
362@overload
363def sliding_window(
364 iterable: Iterable[T], size: Literal[1]
365) -> Iterator[tuple[T]]: # pragma: no cover
366 ...
369@overload
370def sliding_window(
371 iterable: Iterable[T], size: Literal[2]
372) -> Iterator[tuple[T, T]]: # pragma: no cover
373 ...
376@overload
377def sliding_window(
378 iterable: Iterable[T], size: int
379) -> Iterator[tuple[T, ...]]: # pragma: no cover
380 ...
383def sliding_window(iterable: Iterable[T], size: int) -> Iterator[tuple[T, ...]]:
384 """Return overlapping size-lets from an iterable.
386 If len(iterable) < size then an empty iterator is returned.
387 """
388 if size == 1:
389 return map(wrap_in_tuple, iterable)
390 if size == 2:
391 return itertools.pairwise(iterable)
392 return SlidingWindow(iterable, size)