Coverage for typed_stream/_impl/_iteration_utils.py: 98%

176 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-23 18:47 +0000

1# Licensed under the EUPL-1.2 or later. 

2# You may obtain a copy of the licence in all the official languages of the 

3# European Union at https://joinup.ec.europa.eu/collection/eupl/eupl-text-eupl-12 

4 

5"""Utility classes used in streams.""" 

6 

7from __future__ import annotations 

8 

9import collections 

10import contextlib 

11import itertools 

12from collections.abc import Callable, Iterable, Iterator 

13from typing import Generic, Literal, TypeVar, Union, cast, overload 

14 

15from ..streamable import Streamable 

16from ._types import ClassWithCleanUp, IteratorProxy, PrettyRepr 

17from ._typing import Self, override 

18from ._utils import ( 

19 FunctionWrapperIgnoringArgs, 

20 IndexValueTuple, 

21 count_required_positional_arguments, 

22 wrap_in_tuple, 

23) 

24from .functions import one 

25 

# Names exported by ``from <this module> import *``.
__all__ = (
    # classes
    "Chunked",
    "Enumerator",
    "ExceptionHandler",
    "IfElseMap",
    "IterWithCleanUp",
    "Peeker",
    # functions
    "count",
    "sliding_window",
)

36 

# Unconstrained type variables shared by the generic helpers in this module.
T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")

# Type variable for the exception type(s) an ExceptionHandler deals with.
Exc = TypeVar("Exc", bound=BaseException)

42 

43 

def count(it: Iterable[object]) -> int:
    """Return how many items the iterable yields.

    Every item is passed through ``one`` and the results are summed, so the
    iterable is consumed completely.
    """
    per_item = map(one, it)
    return sum(per_item)

47 

48 

class Chunked(
    IteratorProxy[tuple[T, ...], T],
    Streamable[tuple[T, ...]],
    Generic[T],
):
    """Split an iterable into tuples holding chunk_size items each.

    The final tuple is shorter when the items run out early.

    Inspired by batched from:
    https://docs.python.org/3/library/itertools.html?highlight=callable#itertools-recipes

    >>> chunks = Chunked("abcd", 2)
    >>> assert "Chunked" in repr(chunks)
    >>> assert "2" in repr(chunks)
    >>> list(chunks)
    [('a', 'b'), ('c', 'd')]
    """

    chunk_size: int

    __slots__ = ("chunk_size",)

    def __init__(self, iterable: Iterable[T], chunk_size: int) -> None:
        """Set up chunking of iterable into tuples of chunk_size items.

        Raises a ValueError if chunk_size is smaller than one.
        """
        # validate before touching the iterable
        if chunk_size < 1:
            raise ValueError("size must be at least one")
        super().__init__(iterable)
        self.chunk_size = chunk_size

    @override
    def __next__(self) -> tuple[T, ...]:
        """Return the next chunk or raise StopIteration if none is left."""
        chunk = tuple(itertools.islice(self._iterator, self.chunk_size))
        if not chunk:
            # nothing could be taken: the underlying iterator is exhausted
            raise StopIteration()
        return chunk

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the super class' arguments followed by the chunk size."""
        return (*super()._get_args(), self.chunk_size)

88 

89 

class Enumerator(IteratorProxy[IndexValueTuple[T], T], Generic[T]):
    """Pair every item with a running index, like enumerate().

    Unlike enumerate() the pairs are IndexValueTuples.
    """

    _curr_idx: int

    __slots__ = ("_curr_idx",)

    def __init__(self, iterable: Iterable[T], start_index: int) -> None:
        """Enumerate iterable, numbering the first item with start_index."""
        super().__init__(iterable)
        self._curr_idx = start_index

    @override
    def __next__(self: Enumerator[T]) -> IndexValueTuple[T]:
        """Return the next item together with its index."""
        index = self._curr_idx
        # if next() raises, the index below is not advanced
        pair: tuple[int, T] = (index, next(self._iterator))
        self._curr_idx += 1
        return IndexValueTuple(pair)

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the super class' arguments followed by the current index."""
        return (*super()._get_args(), self._curr_idx)

113 

114 

# pylint: disable-next=consider-alternative-union-syntax
class ExceptionHandler(IteratorProxy[Union[T, U], T], Generic[T, U, Exc]):
    """Handle Exceptions in iterators.

    An exception matching exception_class that is raised while fetching the
    next value is passed to log_callable (if one was given). Afterwards the
    result of default_factory is returned in place of the failed value or,
    if there is no default_factory, the next value is fetched instead.
    StopIteration is never handled.
    """

    _exception_class: type[Exc] | tuple[type[Exc], ...]
    _default_fun: Callable[[Exc], U] | None
    _log_fun: Callable[[Exc], object] | None

    __slots__ = ("_exception_class", "_default_fun", "_log_fun")

    def __init__(
        self,
        iterable: Iterable[T],
        exception_class: type[Exc] | tuple[type[Exc], ...],
        log_callable: Callable[[Exc], object] | None = None,
        default_factory: Callable[[Exc], U] | Callable[[], U] | None = None,
    ) -> None:
        """Handle Exceptions in iterables.

        exception_class is the exception type (or a tuple of types) to
        handle. log_callable, if given, is called with every handled
        exception. default_factory, if given, provides the replacement
        value; it may either accept the exception or take no required
        positional arguments at all.

        Raises a ValueError if exception_class is StopIteration or a tuple
        containing it.
        """
        super().__init__(iterable)
        if (
            (StopIteration in exception_class)
            if isinstance(exception_class, tuple)
            else (exception_class == StopIteration)
        ):
            raise ValueError("Cannot catch StopIteration")
        self._exception_class = exception_class
        self._log_fun = log_callable
        if default_factory is None:
            self._default_fun = None
        elif count_required_positional_arguments(default_factory):
            self._default_fun = cast(Callable[[Exc], U], default_factory)
        else:
            # the factory does not want the exception: wrap it so it can
            # still be called with one
            self._default_fun = FunctionWrapperIgnoringArgs(
                cast(Callable[[], U], default_factory)
            )

    @override
    def __next__(self: ExceptionHandler[T, U, Exc]) -> T | U:  # noqa: C901
        """Return the next value or the default for a handled exception."""
        while True:  # pylint: disable=while-used
            try:
                value: T = next(self._iterator)
            except StopIteration:
                # listed first so that a broad exception_class (such as
                # Exception) cannot swallow the end of the iteration
                raise
            except self._exception_class as exc:
                # compare against None instead of relying on truthiness: a
                # callable object may well be falsy (__bool__ / __len__)
                if self._log_fun is not None:
                    self._log_fun(exc)
                if self._default_fun is not None:
                    return self._default_fun(exc)
                # if no default fun is available just return the next element
            else:
                return value

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the super class' arguments followed by the own settings."""
        return (
            *super()._get_args(),
            self._exception_class,
            self._log_fun,
            self._default_fun,
        )

179 

180 

# pylint: disable-next=consider-alternative-union-syntax
class IfElseMap(IteratorProxy[Union[U, V], T], Generic[T, U, V]):
    """Map combined with conditions.

    Values fulfilling the condition are mapped with the if-function, all
    other values with the else-function. Without an else-function the values
    that do not fulfil the condition are dropped.
    """

    _condition: Callable[[T], bool | object]
    _if_fun: Callable[[T], U]
    _else_fun: Callable[[T], V] | None

    __slots__ = ("_condition", "_if_fun", "_else_fun")

    def __init__(
        self,
        iterable: Iterable[T],
        condition: Callable[[T], bool | object],
        if_: Callable[[T], U],
        else_: Callable[[T], V] | None = None,
    ) -> None:
        """Map values depending on a condition.

        Equivalent pairs:
        - map(lambda _: (if_(_) if condition(_) else else_(_)), iterable)
        - IfElseMap(iterable, condition, if_, else_)

        - filter(callable, iterable)
        - IfElseMap(iterable, callable, lambda _: _, None)

        Raises a ValueError if both if_ and else_ are None.
        """
        super().__init__(iterable)
        self._condition = condition
        if if_ is None and else_ is None:
            raise ValueError("if_ and else_ must not both be None")
        self._if_fun = if_
        self._else_fun = else_

    @override
    def __next__(self: IfElseMap[T, U, V]) -> U | V:
        """Return the next mapped value."""
        while True:  # pylint: disable=while-used
            value: T = next(self._iterator)
            if self._condition(value):
                return self._if_fun(value)
            # compare against None instead of relying on truthiness: a
            # callable object may well be falsy (__bool__ / __len__)
            if self._else_fun is not None:
                return self._else_fun(value)
            # no else-function: drop this value and look at the next one

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the super class' arguments followed by the callables."""
        return (
            *super()._get_args(),
            self._condition,
            self._if_fun,
            self._else_fun,
        )

234 

235 

class Peeker(Generic[T], PrettyRepr):
    """Callable that shows each value to fun and hands it back unchanged."""

    fun: Callable[[T], object | None]

    __slots__ = ("fun",)

    def __init__(self, fun: Callable[[T], object | None]) -> None:
        """Remember the function that gets to see every value."""
        self.fun = fun

    def __call__(self, value: T, /) -> T:
        """Pass value to fun, ignore the result and return value."""
        # only called for its side effects
        self.fun(value)
        return value

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return a one-element tuple holding the wrapped function."""
        return (self.fun,)

256 

257 

class IterWithCleanUp(Iterator[T], ClassWithCleanUp):
    """An Iterator that runs a clean-up function once it is finished.

    The clean-up function is called once in one of the following conditions:
    - iteration has been completed
    - .close() gets called
    - .__del__() gets called
    - it's used in a context manager and .__exit__() gets called

    What you shouldn't do (as calling the clean-up function is probably important):
    - calling next(this) just once
    - breaking in a for loop iterating over this without closing this
    - partially iterating over this without closing
    """

    iterator: Iterator[T] | None

    __slots__ = ("iterator",)

    def __init__(
        self, iterable: Iterable[T], cleanup_fun: Callable[[], object | None]
    ) -> None:
        """Wrap iterable and register cleanup_fun as its clean-up function."""
        super().__init__(cleanup_fun)
        self.iterator = iter(iterable)

    @override
    def __iter__(self) -> Self:
        """Return this iterator itself."""
        return self

    @override
    def __next__(self) -> T:
        """Return the next element; close self when there is none."""
        source = self.iterator
        if source is None:
            # already closed: behave like an exhausted iterator
            self.close()
            raise StopIteration
        try:
            item = next(source)
        except BaseException:
            # includes StopIteration; errors raised by close() itself must
            # not hide the exception that got us here
            with contextlib.suppress(Exception):
                self.close()
            raise
        return item

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the super class' arguments followed by the iterator."""
        return (*super()._get_args(), self.iterator)

    @override
    def close(self) -> None:
        """Close the super class and drop the wrapped iterator."""
        super().close()
        # None marks this iterator as closed for __next__
        self.iterator = None

313 

314 

class SlidingWindow(IteratorProxy[tuple[T, ...], T], Generic[T]):
    """Yield overlapping tuples of a fixed size from an iterable.

    Inspired by sliding_window from:
    https://docs.python.org/3/library/itertools.html#itertools-recipes
    """

    _window: collections.deque[T]

    __slots__ = ("_window",)

    def __init__(self, iterable: Iterable[T], size: int) -> None:
        """Slide a window of size items over iterable.

        Raises a ValueError if size is smaller than one.
        """
        if size < 1:
            raise ValueError("size needs to be a positive integer")
        super().__init__(iterable)
        # the deque's maxlen doubles as storage for the window size
        self._window = collections.deque((), maxlen=size)

    @override
    def __next__(self: SlidingWindow[T]) -> tuple[T, ...]:
        """Return the next window as a tuple of size items."""
        window = self._window
        missing = self.size - len(window)
        if missing == 0:
            # window is full: appending pushes the oldest item out
            try:
                window.append(next(self._iterator))
            except StopIteration:
                window.clear()
                raise
            return tuple(window)
        # window is not full yet: try to fill it up in one go
        window.extend(itertools.islice(self._iterator, missing))
        if len(window) < self.size:
            # too few items left for a complete window
            window.clear()
            raise StopIteration()
        return tuple(window)

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the super class' arguments followed by the window size."""
        return (*super()._get_args(), self.size)

    @property
    def size(self) -> int:
        """Return the window size, which is the maxlen of the deque."""
        return cast(int, self._window.maxlen)

360 

361 

362@overload 

363def sliding_window( 

364 iterable: Iterable[T], size: Literal[1] 

365) -> Iterator[tuple[T]]: # pragma: no cover 

366 ... 

367 

368 

369@overload 

370def sliding_window( 

371 iterable: Iterable[T], size: Literal[2] 

372) -> Iterator[tuple[T, T]]: # pragma: no cover 

373 ... 

374 

375 

376@overload 

377def sliding_window( 

378 iterable: Iterable[T], size: int 

379) -> Iterator[tuple[T, ...]]: # pragma: no cover 

380 ... 

381 

382 

383def sliding_window(iterable: Iterable[T], size: int) -> Iterator[tuple[T, ...]]: 

384 """Return overlapping size-lets from an iterable. 

385 

386 If len(iterable) < size then an empty iterator is returned. 

387 """ 

388 if size == 1: 

389 return map(wrap_in_tuple, iterable) 

390 if size == 2: 

391 return itertools.pairwise(iterable) 

392 return SlidingWindow(iterable, size)