Coverage for typed_stream/__main__.py: 76%

130 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-12 21:24 +0000

1# Licensed under the EUPL-1.2 or later. 

2# You may obtain a copy of the licence in all the official languages of the 

3# European Union at https://joinup.ec.europa.eu/collection/eupl/eupl-text-eupl-12 

4 

5"""Easy interface for streamy handling of files.""" 

6 

7from __future__ import annotations 

8 

9import argparse 

10import builtins 

11import collections 

12import dataclasses 

13import inspect 

14import operator 

15import sys 

16import textwrap 

17import traceback 

18from collections.abc import Callable, Mapping 

19from typing import cast 

20 

21from ._impl import Stream, functions 

22from ._impl._utils import count_required_positional_arguments 

23 

# Global namespace for actions that run_program evaluates with eval():
# every name listed in __all__ of collections, typed_stream.functions and
# operator (on a clash the module listed later wins), followed by the modules
# builtins, collections, functions and operator under their own __name__.
EVAL_GLOBALS: Mapping[str, object] = {
    **{
        exported: getattr(module, exported)
        for module in (collections, functions, operator)
        for exported in module.__all__
    },
    **{
        module.__name__: module
        for module in (builtins, collections, functions, operator)
    },
}

32 

33 

34@dataclasses.dataclass(frozen=True, slots=True) 

35class Options: 

36 """The options for this cool program.""" 

37 

38 debug: bool 

39 bytes: bool 

40 keep_ends: bool 

41 no_eval: bool 

42 actions: tuple[str, ...] 

43 

44 

def run_program(options: Options) -> str | None:  # noqa: C901
    # pylint: disable=too-complex, too-many-branches, too-many-locals, too-many-statements  # noqa: B950
    """Run the program with the options.

    A Stream is created over stdin: text lines, or binary lines when
    ``options.bytes`` is set. The trailing newline of each line is removed
    unless ``options.keep_ends`` is set.

    ``options.actions`` is interpreted from left to right. Once the pending
    method has all its required positional arguments, an action naming an
    attribute of the current stream calls the pending method and starts the
    next one. Every other action becomes a positional argument of the
    pending method. Arguments are resolved as follows:

    - whitespace is kept as a string;
    - decimal numbers become ints;
    - names from typed_stream.functions, collections, operator or builtins
      are looked up;
    - anything else is evaluated with eval(), unless ``options.no_eval`` is
      set.

    If the final result is a Stream, every element is printed. Otherwise the
    result is printed unless it is None. With ``options.debug``, the
    equivalent Python code is written to stderr.

    Returns None on success, or an error message if the actions are invalid.

    >>> import contextlib, io
    >>> in_ = sys.stdin
    >>> sys.stdin = io.StringIO("200\\n1000\\n30\\n4")
    >>> with contextlib.redirect_stderr(io.StringIO()) as err:
    ...     run_program(Options(
    ...         debug=True,
    ...         bytes=False,
    ...         keep_ends=False,
    ...         no_eval=False,
    ...         actions=("map", "int", "sum"),
    ...     ))
    1234
    >>> print("\\n".join(err.getvalue().split("\\n")[-2:]))
    print(Stream(sys.stdin).map(str.removesuffix, "\\n").map(int).sum())
    <BLANKLINE>
    >>> sys.stdin = io.StringIO("300\\n1000\\n20\\n4")
    >>> with contextlib.redirect_stderr(io.StringIO()) as err:
    ...     run_program(Options(
    ...         debug=True,
    ...         bytes=False,
    ...         keep_ends=False,
    ...         no_eval=False,
    ...         actions=("map", "int", "collect", "builtins.sum")
    ...     ))
    1324
    >>> print("\\n".join(err.getvalue().split("\\n")[-2:]))
    print(Stream(sys.stdin).map(str.removesuffix, "\\n").map(int).collect(builtins.sum))
    <BLANKLINE>
    >>> sys.stdin = io.StringIO("")
    >>> with contextlib.redirect_stderr(io.StringIO()) as err:
    ...     ret = run_program(Options(
    ...         debug=True,
    ...         bytes=False,
    ...         keep_ends=False,
    ...         no_eval=False,
    ...         actions=("map", "int", "(°_°)")
    ...     ))
    >>> assert not err.getvalue()
    >>> assert isinstance(ret, str)
    >>> assert "SyntaxError" in ret
    >>> assert "(°_°)" in ret
    >>> sys.stdin = io.StringIO("")
    >>> with contextlib.redirect_stderr(io.StringIO()) as err:
    ...     ret = run_program(Options(
    ...         debug=True,
    ...         bytes=False,
    ...         keep_ends=False,
    ...         no_eval=False,
    ...         actions=("map", "xxx")
    ...     ))
    >>> assert not err.getvalue()
    >>> assert isinstance(ret, str)
    >>> assert "NameError" in ret
    >>> assert "xxx" in ret
    >>> sys.stdin = io.StringIO("")
    >>> with contextlib.redirect_stderr(io.StringIO()) as err:
    ...     ret = run_program(Options(
    ...         debug=True,
    ...         bytes=False,
    ...         keep_ends=False,
    ...         no_eval=True,
    ...         actions=("map", "xxx")
    ...     ))
    >>> assert not err.getvalue()
    >>> print(ret)
    Can't parse 'xxx' without eval.
    >>> sys.stdin = io.StringIO("")
    >>> with contextlib.redirect_stderr(io.StringIO()) as err:
    ...     ret = run_program(Options(
    ...         debug=True,
    ...         bytes=False,
    ...         keep_ends=False,
    ...         no_eval=True,
    ...         actions=("map", "int", "collect", "sum")
    ...     ))
    >>> assert not err.getvalue()
    >>> print(ret)
    StreamableSequence object has no attribute 'sum'. \
To pass it as argument to Stream.collect use 'builtins.sum'.
    >>> sys.stdin = io.StringIO("")
    >>> print(run_program(Options(
    ...     debug=False,
    ...     bytes=False,
    ...     keep_ends=False,
    ...     no_eval=True,
    ...     actions=("map", "mul", "²"),
    ... )))
    Can't parse '²' without eval.
    >>> sys.stdin = io.StringIO("0\\n0")
    >>> run_program(Options(
    ...     debug=False,
    ...     bytes=False,
    ...     keep_ends=False,
    ...     no_eval=False,
    ...     actions=("map", "int", "sum"),
    ... ))
    0
    >>> sys.stdin = io.TextIOWrapper(io.BytesIO(b"200\\n1000\\n30\\n4"))
    >>> with contextlib.redirect_stderr(io.StringIO()) as err:
    ...     run_program(Options(
    ...         debug=True,
    ...         bytes=True,
    ...         keep_ends=True,
    ...         no_eval=True,
    ...         actions=("flat_map", "iter", "map", "hex", "collect", "Counter")
    ...     ))
    Counter({'0x30': 6, '0xa': 3, '0x32': 1, '0x31': 1, '0x33': 1, '0x34': 1})
    >>> print("\\n".join(err.getvalue().split("\\n")[-2:]))
    print(Stream(sys.stdin.buffer).flat_map(iter).map(hex).collect(collections.Counter))
    <BLANKLINE>
    >>> sys.stdin = io.TextIOWrapper(io.BytesIO(b"1\\n2\\n3\\n4"))
    >>> with contextlib.redirect_stderr(io.StringIO()) as err:
    ...     run_program(Options(
    ...         debug=True,
    ...         bytes=False,
    ...         keep_ends=True,
    ...         no_eval=True,
    ...         actions=("map", "int", "filter", "is_even", "map", "mul", "10")
    ...     ))
    20
    40
    >>> f"\\n{err.getvalue()}".endswith(
    ...     "Stream(sys.stdin).map(int).filter(typed_stream.functions.is_even)"
    ...     ".map(operator.mul,10).for_each(print)\\n"
    ... )
    True
    >>> sys.stdin = in_
    """  # noqa: D301
    # `code` collects the source of the equivalent Python program (for --debug).
    code: list[str]
    stream: Stream[bytes] | Stream[str] | object
    if options.bytes:
        stream = Stream(sys.stdin.buffer)
        code = ["Stream(sys.stdin.buffer)"]
        if not options.keep_ends:
            code.append(r""".map(bytes.removesuffix, b"\n")""")
            stream = stream.map(bytes.removesuffix, b"\n")
    else:
        stream = Stream(sys.stdin)
        code = ["Stream(sys.stdin)"]
        if not options.keep_ends:
            code.append(r""".map(str.removesuffix, "\n")""")
            stream = stream.map(str.removesuffix, "\n")

    # The pending (bound) Stream method and the arguments gathered for it.
    method: None | Callable[[object], object] = None
    args: list[object] = []
    for index, action in Stream(options.actions).enumerate(1):
        # Refuse private/dunder names so actions cannot reach internals.
        if action.lstrip().startswith("_"):
            return f"{index}: {action!r} isn't allowed to start with '_'."
        args_left = (
            count_required_positional_arguments(method) - len(args)
            if method
            else 0
        )
        if (not args_left or args_left < 0) and hasattr(stream, action):
            if method:
                # The pending call is complete: run it and close its parens.
                stream = method(*args)  # pylint: disable=not-callable
                args.clear()
                if code and code[-1] == ",":
                    code[-1] = ")"
                else:
                    code.append(")")
                if not hasattr(stream, action):
                    # The call returned something (e.g. a collected value)
                    # without this attribute; suggest passing it as argument.
                    type_name = (
                        type(stream).__qualname__ or type(stream).__name__
                    )
                    meth_name = method.__qualname__ or method.__name__
                    if hasattr(builtins, action):
                        fix = f"builtins.{action}"
                        confident = True
                    else:
                        fix = f"({action})"
                        confident = action in EVAL_GLOBALS
                    use = "use" if confident else "try"
                    return (
                        f"{type_name} object has no attribute {action!r}. "
                        f"To pass it as argument to {meth_name} {use} {fix!r}."
                    )
            method = getattr(stream, action)
            code.append(f".{action}(")
        else:
            if not method:
                return f"{action!r} needs to be a Stream method."
            full_action_qual: str
            if action.isspace():
                args.append(action)
                full_action_qual = repr(action)
            elif action.isdecimal():
                # isdecimal (unlike isdigit) only accepts what int() parses:
                # "²".isdigit() is True, but int("²") raises ValueError.
                number = int(action)
                args.append(number)
                # str(number) is a valid Python literal (no leading zeros,
                # ASCII digits only) for the generated debug code.
                full_action_qual = str(number)
            elif action in functions.__all__:
                args.append(getattr(functions, action))
                full_action_qual = f"typed_stream.functions.{action}"
            elif action in collections.__all__:
                args.append(getattr(collections, action))
                full_action_qual = f"collections.{action}"
            elif action in operator.__all__:
                args.append(getattr(operator, action))
                full_action_qual = f"operator.{action}"
            elif hasattr(builtins, action):
                args.append(getattr(builtins, action))
                full_action_qual = f"{action}"
            elif options.no_eval:
                return f"Can't parse {action!r} without eval."
            else:
                try:
                    # Actions come from the command line; main() warns not to
                    # run this with arguments from an untrusted source.
                    # pylint: disable-next=eval-used
                    arg = eval(action, dict(EVAL_GLOBALS))  # nosec: B307
                # pylint: disable-next=broad-except
                except BaseException as exc:  # noqa: B036
                    err = traceback.format_exception_only(exc)[-1].strip()
                    return f"Failed to evaluate {action!r}: {err}"
                args.append(arg)
                full_action_qual = action
            code.extend((full_action_qual, ","))
    # Run the last pending method, if any.
    if method:
        if code and code[-1] == ",":
            code[-1] = ")"
        else:
            code.append(")")
        stream = method(*args)

    if isinstance(stream, Stream):
        # pytype: disable=attribute-error
        stream.for_each(print)
        # pytype: enable=attribute-error
        code.append(".for_each(print)")
    elif stream is not None:
        # Print falsy results too (0, False, empty collections); only a None
        # result (e.g. from for_each) produces no output.
        print(stream)
        code.insert(0, "print(")
        code.append(")")

    sys.stdout.flush()

    if options.debug:
        print("".join(code), file=sys.stderr, flush=True)
    return None

267 

268 

def main() -> str | None:  # noqa: C901
    """Parse arguments and then run the program.

    If the first action is ``help``, the docs of the named Stream methods are
    printed instead of running a pipeline. With no names given, the argparse
    usage is printed.

    Returns None on success, or an error message describing why the actions
    could not be run.
    """
    arg_parser = argparse.ArgumentParser(
        prog="typed_stream",
        description="Easy interface for streamy handling of files.",
        epilog="Do not run this with arguments from an untrusted source.",
    )
    arg_parser.add_argument("--debug", action="store_true")
    arg_parser.add_argument("--bytes", action="store_true")
    arg_parser.add_argument("--keep-ends", action="store_true")
    # --no-eval is not exposed on the command line; no_eval is always False.
    # arg_parser.add_argument("--no-eval", action="store_true")
    arg_parser.add_argument("actions", nargs="+")

    args = arg_parser.parse_args()
    options = Options(
        debug=bool(args.debug),
        bytes=bool(args.bytes),
        keep_ends=bool(args.keep_ends),
        no_eval=False,
        actions=tuple(map(str, args.actions)),
    )
    if options.actions and options.actions[0] == "help":
        if not (methods := options.actions[1:]):
            # parse_args takes the arguments WITHOUT the program name;
            # "--help" prints the usage and exits with status 0.
            arg_parser.parse_args(["--help"])

        for i, name in enumerate(methods):
            if i:
                print()
            print(f"Stream.{name}:")
            if not (method := getattr(Stream, name, None)):
                to_print = "Does not exist."
            elif not (doc := cast(str, getattr(method, "__doc__", ""))):
                to_print = "No docs."
            else:
                to_print = inspect.cleandoc(doc)

            print(textwrap.indent(to_print, " " * 4))
        return None

    return run_program(options)

309 

310 

if __name__ == "__main__":
    # SystemExit(None) exits with status 0. An error message returned by
    # main() is printed to stderr and the process exits with status 1.
    raise SystemExit(main())