1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
|
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package infertypeargs
import (
"go/ast"
"go/token"
"go/types"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/inspector"
"golang.org/x/tools/internal/typeparams"
"golang.org/x/tools/internal/versions"
)
// Doc is the analyzer's documentation. Its first line is the title and is
// separated from the body by a blank line, as analysis.Analyzer.Doc expects.
const Doc = `check for unnecessary type arguments in call expressions

Explicit type arguments may be omitted from call expressions if they can be
inferred from function arguments, or from other type arguments:

	func f[T any](T) {}

	func _() {
		f[string]("foo") // string could be inferred
	}
`
// Analyzer reports explicit type arguments in call expressions that could be
// omitted because the type checker infers the same instantiation without
// them; see Doc.
var Analyzer = &analysis.Analyzer{
	Name:     "infertypeargs",
	Doc:      Doc,
	URL:      "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/infertypeargs",
	Run:      run,
	Requires: []*analysis.Analyzer{inspect.Analyzer},
}
// run is the analyzer's entry point. It asks diagnose for every
// simplification in the package (both position bounds are token.NoPos, so no
// range filtering applies) and reports each one. It produces no result value.
func run(pass *analysis.Pass) (any, error) {
	insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
	found := diagnose(pass.Fset, insp, token.NoPos, token.NoPos, pass.Pkg, pass.TypesInfo)
	for i := range found {
		pass.Report(found[i])
	}
	return nil, nil
}
// diagnose reports diagnostics describing simplifications to type
// arguments overlapping with the provided start and end position.
//
// It considers each call whose callee is an identifier or selector carrying
// explicit type arguments (e.g. f[int, string](x) or pkg.F[int](x)) and
// whose instantiation is recorded in info.Instances. It repeatedly drops the
// rightmost type argument and re-type-checks the trimmed call with
// types.CheckExpr. Trimming stops at the first attempt that fails to
// type-check or that yields an instantiated type not identical to the
// original one. The arguments dropped before that point are reported as
// unnecessary, with a suggested fix that deletes them.
//
// If start or end is token.NoPos, the corresponding bound is not checked
// (i.e. if both start and end are NoPos, all call expressions are considered).
func diagnose(fset *token.FileSet, inspect *inspector.Inspector, start, end token.Pos, pkg *types.Package, info *types.Info) []analysis.Diagnostic {
	var diags []analysis.Diagnostic

	nodeFilter := []ast.Node{(*ast.CallExpr)(nil)}
	inspect.Preorder(nodeFilter, func(node ast.Node) {
		call := node.(*ast.CallExpr)
		x, lbrack, indices, rbrack := typeparams.UnpackIndexExpr(call.Fun)
		ident := calledIdent(x)
		if ident == nil || len(indices) == 0 {
			return // no explicit args, nothing to do
		}

		if (start.IsValid() && call.End() < start) || (end.IsValid() && call.Pos() > end) {
			return // non-overlapping
		}

		// Confirm that instantiation actually occurred at this ident.
		idata, ok := info.Instances[ident]
		if !ok {
			return // something went wrong, but fail open
		}
		instance := idata.Type

		// Start removing argument expressions from the right, and check if we can
		// still infer the call expression.
		required := len(indices) // number of type expressions that are required
		for i := len(indices) - 1; i >= 0; i-- {
			var fun ast.Expr
			if i == 0 {
				// No longer an index expression: just use the parameterized operand.
				fun = x
			} else {
				// Keep only the first i type arguments; the synthetic closing
				// bracket is positioned right after the last one kept.
				fun = typeparams.PackIndexExpr(x, lbrack, indices[:i], indices[i-1].End())
			}
			newCall := &ast.CallExpr{
				Fun:      fun,
				Lparen:   call.Lparen,
				Args:     call.Args,
				Ellipsis: call.Ellipsis,
				Rparen:   call.Rparen,
			}

			// Type-check the trimmed call into a fresh Info so that the
			// caller's info is never modified. It is deliberately not named
			// "info", which would shadow the parameter read above.
			trialInfo := &types.Info{
				Instances: make(map[*ast.Ident]types.Instance),
			}
			versions.InitFileVersions(trialInfo)
			if err := types.CheckExpr(fset, pkg, call.Pos(), newCall, trialInfo); err != nil {
				// Most likely inference failed.
				break
			}

			// newCall reuses x, so the same ident node keys the new instance.
			newIData, ok := trialInfo.Instances[ident]
			if !ok || !types.Identical(instance, newIData.Type) {
				// No instance was recorded, or the inferred result type does not
				// match the original result type, so this simplification is not
				// valid.
				break
			}
			required = i
		}

		if required == len(indices) {
			return // every explicit type argument is needed
		}

		var s, e token.Pos
		var edit analysis.TextEdit
		if required == 0 {
			s, e = lbrack, rbrack+1 // erase the entire index
			edit = analysis.TextEdit{Pos: s, End: e}
		} else {
			s = indices[required].Pos()
			e = rbrack
			// Erase from the end of the last kept argument, so the separating
			// comma and whitespace are removed as well.
			edit = analysis.TextEdit{Pos: indices[required-1].End(), End: e}
		}

		// Recheck that our (narrower) fixes overlap with the requested range.
		if (start.IsValid() && e < start) || (end.IsValid() && s > end) {
			return // non-overlapping
		}

		diags = append(diags, analysis.Diagnostic{
			Pos:     s,
			End:     e,
			Message: "unnecessary type arguments",
			SuggestedFixes: []analysis.SuggestedFix{{
				Message:   "Simplify type arguments",
				TextEdits: []analysis.TextEdit{edit},
			}},
		})
	})

	return diags
}
// calledIdent returns the identifier that names the callee in the operand x
// of a call. If x is a selector expression (pkg.F or v.M), the result is the
// selected name. If x is a plain identifier, the result is x itself. For any
// other expression, including a nil x, the result is nil.
func calledIdent(x ast.Expr) *ast.Ident {
	if sel, isSel := x.(*ast.SelectorExpr); isSel {
		return sel.Sel
	}
	id, _ := x.(*ast.Ident) // nil unless x is an *ast.Ident
	return id
}
|