Coverage for typed_stream/_impl/_iteration_utils.py: 98%

176 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-12 21:24 +0000

1# Licensed under the EUPL-1.2 or later. 

2# You may obtain a copy of the licence in all the official languages of the 

3# European Union at https://joinup.ec.europa.eu/collection/eupl/eupl-text-eupl-12 

4 

5"""Utility classes used in streams.""" 

6 

7from __future__ import annotations 

8 

9import collections 

10import contextlib 

11import itertools 

12from collections.abc import Callable, Iterable, Iterator 

13from typing import Generic, Literal, TypeVar, cast, overload 

14 

15from ..streamable import Streamable 

16from ._types import ClassWithCleanUp, IteratorProxy, PrettyRepr 

17from ._typing import Self, override 

18from ._utils import ( 

19 FunctionWrapperIgnoringArgs, 

20 IndexValueTuple, 

21 count_required_positional_arguments, 

22 wrap_in_tuple, 

23) 

24from .functions import one 

25 

# Names exported by ``from ... import *``. SlidingWindow is not listed here;
# it is reachable through the sliding_window() function below.
__all__ = (
    "Chunked", "Enumerator", "ExceptionHandler", "IfElseMap",
    "IterWithCleanUp", "Peeker", "count", "sliding_window",
)

# Generic type variables used by the iterator helpers in this module.
T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")

# Exception type variable for ExceptionHandler; restricted to BaseException.
Exc = TypeVar("Exc", bound=BaseException)

42 

43 

def count(it: Iterable[object]) -> int:
    """Return the number of items in *it*.

    The iterable is consumed; every item contributes ``one(item)`` to the sum.
    """
    return sum(one(item) for item in it)

47 

48 

class Chunked(
    IteratorProxy[tuple[T, ...], T],
    Streamable[tuple[T, ...]],
    Generic[T],
):
    """Split data into tuples of ``chunk_size`` items; the last may be shorter.

    Inspired by batched from:
    https://docs.python.org/3/library/itertools.html?highlight=callable#itertools-recipes

    >>> chunks = Chunked("abcd", 2)
    >>> assert "Chunked" in repr(chunks)
    >>> assert "2" in repr(chunks)
    >>> list(chunks)
    [('a', 'b'), ('c', 'd')]
    >>> list(Chunked("abcde", 2))
    [('a', 'b'), ('c', 'd'), ('e',)]
    """

    chunk_size: int  # maximum number of items per yielded tuple

    __slots__ = ("chunk_size",)

    def __init__(self, iterable: Iterable[T], chunk_size: int) -> None:
        """Prepare to yield tuples of ``chunk_size`` items from *iterable*.

        Raises ValueError when ``chunk_size`` is smaller than one.
        """
        if chunk_size < 1:
            raise ValueError("size must be at least one")
        super().__init__(iterable)
        self.chunk_size = chunk_size

    @override
    def __next__(self) -> tuple[T, ...]:
        """Return the next chunk; raise StopIteration once no items remain."""
        chunk = tuple(itertools.islice(self._iterator, self.chunk_size))
        if not chunk:
            raise StopIteration()
        return chunk

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the base arguments followed by the chunk size (shown in repr)."""
        base_args = super()._get_args()
        return (*base_args, self.chunk_size)

88 

89 

class Enumerator(IteratorProxy[IndexValueTuple[T], T], Generic[T]):
    """Like enumerate() but yielding IndexValueTuples."""

    _curr_idx: int  # index that the next yielded value will be paired with

    __slots__ = ("_curr_idx",)

    def __init__(self, iterable: Iterable[T], start_index: int) -> None:
        """Pair the values of *iterable* with indices counting from start_index."""
        super().__init__(iterable)
        self._curr_idx = start_index

    @override
    def __next__(self: Enumerator[T]) -> IndexValueTuple[T]:
        """Return the next (index, value) pair wrapped in an IndexValueTuple."""
        index = self._curr_idx
        value = next(self._iterator)
        # only advance the counter once a value was actually produced
        self._curr_idx = index + 1
        return IndexValueTuple((index, value))

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the base arguments plus the index of the next value."""
        return (*super()._get_args(), self._curr_idx)

113 

114 

class ExceptionHandler(IteratorProxy[T | U, T], Generic[T, U, Exc]):
    """Handle Exceptions in iterators.

    Exceptions matching ``exception_class`` that are raised while advancing the
    wrapped iterator are passed to the log callable (if any). If a default
    factory was given, its result is returned in place of the failed item;
    otherwise the wrapped iterator is simply advanced again.
    """

    _exception_class: type[Exc] | tuple[type[Exc], ...]
    _default_fun: Callable[[Exc], U] | None
    _log_fun: Callable[[Exc], object] | None

    __slots__ = ("_exception_class", "_default_fun", "_log_fun")

    def __init__(
        self,
        iterable: Iterable[T],
        exception_class: type[Exc] | tuple[type[Exc], ...],
        log_callable: Callable[[Exc], object] | None = None,
        default_factory: Callable[[Exc], U] | Callable[[], U] | None = None,
    ) -> None:
        """Handle Exceptions in iterables.

        Args:
            iterable: The values to iterate over.
            exception_class: Exception type, or tuple of types, to handle.
            log_callable: Called with every handled exception, if given.
            default_factory: Produces the value returned instead of a failed
                item. It may accept the exception as its single argument or
                take no arguments at all.

        Raises:
            ValueError: If ``exception_class`` is, or contains, StopIteration.
        """
        super().__init__(iterable)
        catches_stop_iteration = (
            StopIteration in exception_class
            if isinstance(exception_class, tuple)
            else exception_class is StopIteration
        )
        if catches_stop_iteration:
            raise ValueError("Cannot catch StopIteration")
        self._exception_class = exception_class
        self._log_fun = log_callable
        if default_factory is None:
            self._default_fun = None
        elif count_required_positional_arguments(default_factory):
            self._default_fun = cast(Callable[[Exc], U], default_factory)
        else:
            # presumably a zero-argument factory: wrap it so that it can be
            # called with the exception like the one-argument form
            self._default_fun = FunctionWrapperIgnoringArgs(
                cast(Callable[[], U], default_factory)
            )

    @override
    def __next__(self: ExceptionHandler[T, U, Exc]) -> T | U:  # noqa: C901
        """Return the next value, or a default for a handled exception."""
        while True:  # pylint: disable=while-used
            try:
                value: T = next(self._iterator)
            except StopIteration:
                # the end of iteration always propagates, even if a broad
                # exception_class (e.g. Exception) would match it
                raise
            except self._exception_class as exc:
                if self._log_fun is not None:
                    self._log_fun(exc)
                if self._default_fun is not None:
                    return self._default_fun(exc)
                # no default factory: skip the failed item and try the next
            else:
                return value

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the args used to initialize self."""
        return (
            *super()._get_args(),
            self._exception_class,
            self._log_fun,
            self._default_fun,
        )

178 

179 

class IfElseMap(IteratorProxy[U | V, T], Generic[T, U, V]):
    """Map combined with conditions."""

    _condition: Callable[[T], bool | object]
    _if_fun: Callable[[T], U]
    _else_fun: Callable[[T], V] | None

    __slots__ = ("_condition", "_if_fun", "_else_fun")

    def __init__(
        self,
        iterable: Iterable[T],
        condition: Callable[[T], bool | object],
        if_: Callable[[T], U],
        else_: Callable[[T], V] | None = None,
    ) -> None:
        """Map values depending on a condition.

        Equivalent pairs:
        - map(lambda _: (if_(_) if condition(_) else else_(_)), iterable)
        - IfElseMap(iterable, condition, if_, else_)

        - filter(callable, iterable)
        - IfElseMap(iterable, callable, lambda _: _, None)

        If ``else_`` is None, values failing the condition are dropped.

        Raises:
            ValueError: If both ``if_`` and ``else_`` are None.
        """
        super().__init__(iterable)
        self._condition = condition
        if if_ is None and else_ is None:
            raise ValueError("if_ and else_ cannot both be None")
        self._if_fun = if_
        self._else_fun = else_

    @override
    def __next__(self: IfElseMap[T, U, V]) -> U | V:
        """Return the next mapped value."""
        while True:  # pylint: disable=while-used
            value: T = next(self._iterator)
            if self._condition(value):
                return self._if_fun(value)
            if self._else_fun is not None:
                return self._else_fun(value)
            # no else function: drop this value and continue with the next one

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the args used to initialize self."""
        return (
            *super()._get_args(),
            self._condition,
            self._if_fun,
            self._else_fun,
        )

232 

233 

class Peeker(Generic[T], PrettyRepr):
    """Callable that hands each value to ``fun`` and returns it unchanged.

    Whatever ``fun`` returns is ignored.
    """

    fun: Callable[[T], object | None]  # called with every peeked value

    __slots__ = ("fun",)

    def __init__(self, fun: Callable[[T], object | None]) -> None:
        """Remember the callable that every value is passed to."""
        self.fun = fun

    def __call__(self, value: T, /) -> T:
        """Pass *value* to ``fun``, discard the result and return *value*."""
        _ = self.fun(value)
        return value

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the arguments this instance was created with."""
        args: tuple[object, ...] = (self.fun,)
        return args

254 

255 

class IterWithCleanUp(Iterator[T], ClassWithCleanUp):
    """An Iterator that calls a clean-up function when finished.

    The clean-up function is called once in one of the following conditions:
    - iteration has been completed
    - .close() gets called
    - .__del__() gets called
    - it's used in a context manager and .__exit__() gets called

    What you shouldn't do (as calling the clean-up function is probably important):
    - calling next(this) just once
    - breaking in a for loop iterating over this without closing this
    - partially iterating over this without closing
    """

    iterator: Iterator[T] | None  # set to None once this has been closed

    __slots__ = ("iterator",)

    def __init__(
        self, iterable: Iterable[T], cleanup_fun: Callable[[], object | None]
    ) -> None:
        """Wrap *iterable* and register *cleanup_fun* as the clean-up function."""
        super().__init__(cleanup_fun)
        self.iterator = iter(iterable)

    @override
    def __iter__(self) -> Self:
        """Return self."""
        return self

    @override
    def __next__(self) -> T:
        """Return the next element if available else run close."""
        if self.iterator is None:
            self.close()
            raise StopIteration
        try:
            return next(self.iterator)
        except BaseException:
            # exhaustion (StopIteration) or any error, including e.g.
            # KeyboardInterrupt, ends the iteration: run the clean-up, but do
            # not let an Exception raised by it replace the original one
            with contextlib.suppress(Exception):
                self.close()
            raise

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the args used to initialize self."""
        return *super()._get_args(), self.iterator

    @override
    def close(self) -> None:
        """Run clean-up if not run yet and drop the reference to the iterator.

        The slot is assigned without reading it first, so this also works on an
        instance whose ``__init__`` failed before ``iterator`` was set (e.g.
        because ``iter(iterable)`` raised). Such an instance may still be
        closed later (per the class docstring, ``__del__`` triggers clean-up),
        and reading the unset slot would raise AttributeError there.
        """
        super().close()
        self.iterator = None

311 

312 

class SlidingWindow(IteratorProxy[tuple[T, ...], T], Generic[T]):
    """Return overlapping n-lets from an iterable.

    Inspired by sliding_window from:
    https://docs.python.org/3/library/itertools.html#itertools-recipes
    """

    # The current window. Its maxlen is the window size, so appending to a
    # full window evicts the oldest item.
    _window: collections.deque[T]

    __slots__ = ("_window",)

    def __init__(self, iterable: Iterable[T], size: int) -> None:
        """Prepare windows of ``size`` consecutive items over *iterable*.

        Raises ValueError when ``size`` is smaller than one.
        """
        if size < 1:
            raise ValueError("size needs to be a positive integer")
        super().__init__(iterable)
        self._window = collections.deque((), maxlen=size)

    @override
    def __next__(self: SlidingWindow[T]) -> tuple[T, ...]:
        """Return the next window as a tuple of ``size`` items."""
        window = self._window
        missing = self.size - len(window)
        if not missing:
            # full window: slide it forward by a single item
            try:
                window.append(next(self._iterator))
            except StopIteration:
                window.clear()
                raise
            return tuple(window)
        # empty window (start or after exhaustion): fill it completely, or stop
        window.extend(itertools.islice(self._iterator, missing))
        if len(window) < self.size:
            window.clear()
            raise StopIteration()
        return tuple(window)

    @override
    def _get_args(self) -> tuple[object, ...]:
        """Return the base arguments followed by the window size."""
        return (*super()._get_args(), self.size)

    @property
    def size(self) -> int:
        """Return the size of the sliding window (the deque's maxlen)."""
        return cast(int, self._window.maxlen)

358 

359 

360@overload 

361def sliding_window( 

362 iterable: Iterable[T], size: Literal[1] 

363) -> Iterator[tuple[T]]: # pragma: no cover 

364 ... 

365 

366 

367@overload 

368def sliding_window( 

369 iterable: Iterable[T], size: Literal[2] 

370) -> Iterator[tuple[T, T]]: # pragma: no cover 

371 ... 

372 

373 

374@overload 

375def sliding_window( 

376 iterable: Iterable[T], size: int 

377) -> Iterator[tuple[T, ...]]: # pragma: no cover 

378 ... 

379 

380 

381def sliding_window(iterable: Iterable[T], size: int) -> Iterator[tuple[T, ...]]: 

382 """Return overlapping size-lets from an iterable. 

383 

384 If len(iterable) < size then an empty iterator is returned. 

385 """ 

386 if size == 1: 

387 return map(wrap_in_tuple, iterable) 

388 if size == 2: 

389 return itertools.pairwise(iterable) 

390 return SlidingWindow(iterable, size)