Coverage for typed_stream/_impl/_iteration_utils.py: 98%
176 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-12 21:24 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-12 21:24 +0000
1# Licensed under the EUPL-1.2 or later.
2# You may obtain a copy of the licence in all the official languages of the
3# European Union at https://joinup.ec.europa.eu/collection/eupl/eupl-text-eupl-12
5"""Utility classes used in streams."""
7from __future__ import annotations
9import collections
10import contextlib
11import itertools
12from collections.abc import Callable, Iterable, Iterator
13from typing import Generic, Literal, TypeVar, cast, overload
15from ..streamable import Streamable
16from ._types import ClassWithCleanUp, IteratorProxy, PrettyRepr
17from ._typing import Self, override
18from ._utils import (
19 FunctionWrapperIgnoringArgs,
20 IndexValueTuple,
21 count_required_positional_arguments,
22 wrap_in_tuple,
23)
24from .functions import one
# Names exported by this module.
# NOTE(review): the SlidingWindow class defined in this module is not listed
# here, only the sliding_window function is - confirm this is intentional.
__all__ = (
    "Chunked",
    "Enumerator",
    "ExceptionHandler",
    "IfElseMap",
    "IterWithCleanUp",
    "Peeker",
    "count",
    "sliding_window",
)

# Generic type variables used in the annotations of this module.
T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")

# Type variable bound to BaseException, used for the handled exception type.
Exc = TypeVar("Exc", bound=BaseException)
def count(it: Iterable[object]) -> int:
    """Return how many items the iterable yields.

    The iterable is consumed completely while counting.
    """
    return sum(one(item) for item in it)
class Chunked(
    IteratorProxy[tuple[T, ...], T],
    Streamable[tuple[T, ...]],
    Generic[T],
):
    """Split data into tuples of chunk_size items; only the last may be shorter.

    Based on the batched recipe from:
    https://docs.python.org/3/library/itertools.html?highlight=callable#itertools-recipes

    >>> chunks = Chunked("abcd", 2)
    >>> assert "Chunked" in repr(chunks)
    >>> assert "2" in repr(chunks)
    >>> list(chunks)
    [('a', 'b'), ('c', 'd')]
    """

    chunk_size: int

    __slots__ = ("chunk_size",)

    def __init__(self, iterable: Iterable[T], chunk_size: int) -> None:
        """Prepare chunking of iterable into tuples of chunk_size items.

        Raises ValueError if chunk_size is smaller than one.
        """
        # validate before touching the iterable
        if chunk_size < 1:
            raise ValueError("size must be at least one")
        super().__init__(iterable)
        self.chunk_size = chunk_size

    @override
    def __next__(self) -> tuple[T, ...]:
        """Return the next chunk or raise StopIteration if none is left."""
        chunk = tuple(itertools.islice(self._iterator, self.chunk_size))
        if not chunk:
            # nothing could be taken from the underlying iterator
            raise StopIteration()
        return chunk

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the arguments self was initialized with."""
        parent_args = super()._get_args()
        return (*parent_args, self.chunk_size)
class Enumerator(IteratorProxy[IndexValueTuple[T], T], Generic[T]):
    """Counterpart of enumerate() that produces IndexValueTuples."""

    _curr_idx: int

    __slots__ = ("_curr_idx",)

    def __init__(self, iterable: Iterable[T], start_index: int) -> None:
        """Enumerate iterable, giving its first item the index start_index."""
        super().__init__(iterable)
        self._curr_idx = start_index

    @override
    def __next__(self: Enumerator[T]) -> IndexValueTuple[T]:
        """Pair the next item with the current index, then advance the index."""
        index = self._curr_idx
        # Fetch before incrementing: when the underlying iterator is
        # exhausted, StopIteration propagates and the index is left untouched.
        item: T = next(self._iterator)
        self._curr_idx += 1
        return IndexValueTuple((index, item))

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the arguments describing self (iterable args, current index)."""
        parent_args = super()._get_args()
        return (*parent_args, self._curr_idx)
class ExceptionHandler(IteratorProxy[T | U, T], Generic[T, U, Exc]):
    """Handle exceptions raised while iterating over an iterator.

    An exception matching the configured class(es) is passed to the optional
    log callable. Afterwards it is either replaced by the value produced by
    the default function or, if there is none, the element is skipped.
    StopIteration is always re-raised and never handled.
    """

    _exception_class: type[Exc] | tuple[type[Exc], ...]
    _default_fun: Callable[[Exc], U] | None
    _log_fun: Callable[[Exc], object] | None

    __slots__ = ("_exception_class", "_default_fun", "_log_fun")

    def __init__(
        self,
        iterable: Iterable[T],
        exception_class: type[Exc] | tuple[type[Exc], ...],
        log_callable: Callable[[Exc], object] | None = None,
        default_factory: Callable[[Exc], U] | Callable[[], U] | None = None,
    ) -> None:
        """Handle exceptions in iterables.

        Args:
            iterable: the data source to iterate over.
            exception_class: the exception class (or tuple of classes) to
                handle.
            log_callable: optional callable receiving each handled exception.
            default_factory: optional callable producing a replacement value
                for a handled exception. A factory without required
                positional arguments is wrapped in FunctionWrapperIgnoringArgs
                (presumably so it gets called without the exception).

        Raises:
            ValueError: if exception_class is, or contains, StopIteration.
        """
        super().__init__(iterable)
        if isinstance(exception_class, tuple):
            catches_stop_iteration = StopIteration in exception_class
        else:
            catches_stop_iteration = exception_class == StopIteration
        if catches_stop_iteration:
            raise ValueError("Cannot catch StopIteration")
        self._exception_class = exception_class
        self._log_fun = log_callable
        if default_factory is None:
            self._default_fun = None
        elif count_required_positional_arguments(default_factory):
            self._default_fun = cast(Callable[[Exc], U], default_factory)
        else:
            self._default_fun = FunctionWrapperIgnoringArgs(
                cast(Callable[[], U], default_factory)
            )

    @override
    def __next__(self: ExceptionHandler[T, U, Exc]) -> T | U:  # noqa: C901
        """Return the next value, handling the configured exceptions.

        Raises StopIteration once the underlying iterator is exhausted.
        """
        while True:  # pylint: disable=while-used
            try:
                value: T = next(self._iterator)
            except StopIteration:
                # listed first so the end of iteration is never handled,
                # even if the configured classes would match it
                raise
            except self._exception_class as exc:
                # Compare against None explicitly (as __init__ does) so that
                # callables which evaluate as falsy are not silently ignored.
                if self._log_fun is not None:
                    self._log_fun(exc)
                if self._default_fun is not None:
                    return self._default_fun(exc)
                # without a default function just continue with the next
                # element
            else:
                return value

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the arguments describing self."""
        return (
            *super()._get_args(),
            self._exception_class,
            self._log_fun,
            self._default_fun,
        )
class IfElseMap(IteratorProxy[U | V, T], Generic[T, U, V]):
    """Map combined with conditions.

    Each value is mapped with the if-function when the condition holds for
    it. Otherwise it is mapped with the else-function, or skipped when no
    else-function was given.
    """

    _condition: Callable[[T], bool | object]
    _if_fun: Callable[[T], U]
    _else_fun: Callable[[T], V] | None

    __slots__ = ("_condition", "_if_fun", "_else_fun")

    def __init__(
        self,
        iterable: Iterable[T],
        condition: Callable[[T], bool | object],
        if_: Callable[[T], U],
        else_: Callable[[T], V] | None = None,
    ) -> None:
        """Map values depending on a condition.

        Equivalent pairs:
        - map(lambda _: (if_(_) if condition(_) else else_(_)), iterable)
        - IfElseMap(iterable, condition, if_, else_)

        - filter(callable, iterable)
        - IfElseMap(iterable, callable, lambda _: _, None)

        Raises:
            ValueError: if both if_ and else_ are None.
        """
        super().__init__(iterable)
        self._condition = condition
        if if_ is else_ is None:
            # an empty message would leave the caller guessing what is wrong
            raise ValueError("if_ and else_ cannot both be None")
        self._if_fun = if_
        self._else_fun = else_

    @override
    def __next__(self: IfElseMap[T, U, V]) -> U | V:
        """Return the next mapped value.

        Raises StopIteration once the underlying iterator is exhausted.
        """
        while True:  # pylint: disable=while-used
            value: T = next(self._iterator)
            if self._condition(value):
                return self._if_fun(value)
            # Compare against None explicitly so that an else-function which
            # evaluates as falsy is still called instead of being ignored.
            if self._else_fun is not None:
                return self._else_fun(value)
            # no else-function: skip this value and try the next element

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the arguments describing self."""
        return (
            *super()._get_args(),
            self._condition,
            self._if_fun,
            self._else_fun,
        )
class Peeker(Generic[T], PrettyRepr):
    """Callable that shows each value to a function and passes it through."""

    fun: Callable[[T], object | None]

    __slots__ = ("fun",)

    def __init__(self, fun: Callable[[T], object | None]) -> None:
        """Store the function that gets to see every value."""
        self.fun = fun

    def __call__(self, value: T, /) -> T:
        """Invoke fun with value, discard its result and return value."""
        observer = self.fun
        observer(value)
        return value

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the arguments self was initialized with."""
        return (self.fun,)
class IterWithCleanUp(Iterator[T], ClassWithCleanUp):
    """Iterator that runs a clean-up function once it is finished.

    The clean-up function gets called once when one of these happens:
    - the iteration has been completed
    - .close() gets called
    - .__del__() gets called
    - it's used in a context manager and .__exit__() gets called

    Things to avoid (since running the clean-up function is probably
    important):
    - calling next(this) just once
    - breaking out of a for loop over this without closing this
    - iterating over this only partially without closing
    """

    iterator: Iterator[T] | None

    __slots__ = ("iterator",)

    def __init__(
        self, iterable: Iterable[T], cleanup_fun: Callable[[], object | None]
    ) -> None:
        """Wrap iterable and register cleanup_fun as clean-up function."""
        super().__init__(cleanup_fun)
        self.iterator = iter(iterable)

    @override
    def __iter__(self) -> Self:
        """Return this iterator itself."""
        return self

    @override
    def __next__(self) -> T:
        """Return the next element; close self when iteration ends or fails."""
        source = self.iterator
        if source is None:
            # already closed: nothing left to yield
            self.close()
            raise StopIteration
        try:
            item = next(source)
        except BaseException:
            # Run the clean-up on exhaustion and on any error; exceptions
            # raised while closing must not hide the one being propagated.
            with contextlib.suppress(Exception):
                self.close()
            raise
        return item

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the arguments describing self."""
        parent_args = super()._get_args()
        return (*parent_args, self.iterator)

    @override
    def close(self) -> None:
        """Run the clean-up if it has not run yet and drop the iterator."""
        super().close()
        if self.iterator is None:
            return
        self.iterator = None
class SlidingWindow(IteratorProxy[tuple[T, ...], T], Generic[T]):
    """Yield overlapping tuples of a fixed size from an iterable.

    Based on the sliding_window recipe from:
    https://docs.python.org/3/library/itertools.html#itertools-recipes
    """

    _window: collections.deque[T]

    __slots__ = ("_window",)

    def __init__(self, iterable: Iterable[T], size: int) -> None:
        """Set up a window of size items over iterable.

        Raises ValueError if size is smaller than one.
        """
        if size < 1:
            raise ValueError("size needs to be a positive integer")
        super().__init__(iterable)
        # the deque's maxlen doubles as storage for the window size
        self._window = collections.deque((), maxlen=size)

    @override
    def __next__(self: SlidingWindow[T]) -> tuple[T, ...]:
        """Return the next window as a tuple of size items."""
        window = self._window
        missing = self.size - len(window)
        if missing:
            # window not filled yet: try to fill it up completely
            window.extend(itertools.islice(self._iterator, missing))
            if len(window) < self.size:
                # not enough items for a full window
                window.clear()
                raise StopIteration()
            return tuple(window)
        try:
            # the bounded deque drops its oldest item when appending
            window.append(next(self._iterator))
        except StopIteration:
            window.clear()
            raise
        return tuple(window)

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the arguments self was initialized with."""
        parent_args = super()._get_args()
        return (*parent_args, self.size)

    @property
    def size(self) -> int:
        """Return the number of items per window."""
        max_length = self._window.maxlen
        return cast(int, max_length)
360@overload
361def sliding_window(
362 iterable: Iterable[T], size: Literal[1]
363) -> Iterator[tuple[T]]: # pragma: no cover
364 ...
367@overload
368def sliding_window(
369 iterable: Iterable[T], size: Literal[2]
370) -> Iterator[tuple[T, T]]: # pragma: no cover
371 ...
374@overload
375def sliding_window(
376 iterable: Iterable[T], size: int
377) -> Iterator[tuple[T, ...]]: # pragma: no cover
378 ...
381def sliding_window(iterable: Iterable[T], size: int) -> Iterator[tuple[T, ...]]:
382 """Return overlapping size-lets from an iterable.
384 If len(iterable) < size then an empty iterator is returned.
385 """
386 if size == 1:
387 return map(wrap_in_tuple, iterable)
388 if size == 2:
389 return itertools.pairwise(iterable)
390 return SlidingWindow(iterable, size)