from dataclasses import dataclass

from mypy.nodes import (
    ArgKind,
    CallExpr,
    GeneratorExpr,
    ListComprehension,
    NameExpr,
    SetComprehension,
    TupleExpr,
)

from refurb.error import Error


@dataclass
class ErrorInfo(Error):
    """
    If you only want to iterate and unpack values so that you can pass them
    to a function (in the same order and with no modifications), you should
    use the more performant `starmap` function:

    Bad:

    ```
    scores = [85, 100, 60]
    passing_scores = [60, 80, 70]

    def passed_test(score: int, passing_score: int) -> bool:
        return score >= passing_score

    passed_all_tests = all(
        passed_test(score, passing_score)
        for score, passing_score
        in zip(scores, passing_scores)
    )
    ```

    Good:

    ```
    from itertools import starmap

    scores = [85, 100, 60]
    passing_scores = [60, 80, 70]

    def passed_test(score: int, passing_score: int) -> bool:
        return score >= passing_score

    passed_all_tests = all(starmap(passed_test, zip(scores, passing_scores)))
    ```
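
    The same rewrite applies to list and set comprehensions, which become
    `list(starmap(f, x))` and `set(starmap(f, x))`. A call that reorders or
    modifies the unpacked values is not a candidate, because `starmap` always
    passes each tuple's items in their original order. A short sketch, where
    `pairs`, `multiply` and `subtract` are illustrative names only:

    ```
    from itertools import starmap

    # Illustrative data and helpers (hypothetical, not part of refurb)
    pairs = [(2, 3), (4, 5)]

    def multiply(a: int, b: int) -> int:
        return a * b

    def subtract(a: int, b: int) -> int:
        return a - b

    # Instead of `[multiply(a, b) for a, b in pairs]`
    products = list(starmap(multiply, pairs))

    # Instead of `{multiply(a, b) for a, b in pairs}`
    unique_products = set(starmap(multiply, pairs))

    # Not flagged: the unpacked names are passed in a different order
    differences = [subtract(b, a) for a, b in pairs]
    ```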
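    """

    name = "use-starmap"
    code = 140
    msg: str = "Replace `f(...) for ... in x` with `starmap(f, x)`"
    categories = ("itertools", "performance")


# Ids of generator expressions that have already been reported. A list or set
# comprehension wraps a GeneratorExpr that is also visited on its own, so this
# keeps the same code from being reported twice.
ignore = set[int]()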


def check_generator(
    node: GeneratorExpr,
    errors: list[Error],
    old_wrapper: str = "{}",
    new_wrapper: str = "{}",
) -> None:
    match node:
        # Match `f(a, b, ...) for a, b, ... in x`: a single `for` clause that
        # unpacks into a tuple, with one positional argument per unpacked name
        case GeneratorExpr(
            left_expr=CallExpr(args=args, arg_kinds=arg_kinds),
            indices=[TupleExpr(items=names)],
        ) if (
            names
            and len(names) == len(args)
            and all(kind == ArgKind.ARG_POS for kind in arg_kinds)
        ):
            # Each argument must be the unpacked name at the same position;
            # reordered or modified arguments cannot be replaced by `starmap`
            for lhs, rhs in zip(args, names):
                if not (
                    isinstance(lhs, NameExpr)
                    and isinstance(rhs, NameExpr)
                    and lhs.name == rhs.name
                ):
                    return

            ignore.add(id(node))

            # Wrap both snippets so list/set comprehensions get a matching message
            old = "f(...) for ... in x"
            old = old_wrapper.format(old)

            new = "starmap(f, x)"
            new = new_wrapper.format(new)

            msg = f"Replace `{old}` with `{new}`"

            errors.append(ErrorInfo.from_node(node, msg))


def check(
    node: GeneratorExpr | ListComprehension | SetComprehension,
    errors: list[Error],
) -> None:
    if id(node) in ignore:
        return

    match node:
        case GeneratorExpr():
            check_generator(node, errors)

        case ListComprehension(generator=g):
            check_generator(
                g,
                errors,
                old_wrapper="[{}]",
                new_wrapper="list({})",
            )

        case SetComprehension(generator=g):
            # `{{` and `}}` are escaped braces, so this formats to `{...}`
            check_generator(
                g,
                errors,
                old_wrapper="{{{}}}",
                new_wrapper="set({})",
            )