1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
|
package command
import (
"fmt"
"go/build"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
)
// Detect works out which counterfeiter invocations should be performed.
//
// When generateMode is true, the package in cwd is scanned and every generate
// directive found in its Go files is returned as an Invocation. Otherwise a
// single Invocation is built from args, using the GOFILE and GOLINE
// environment variables for its file and line. A GOLINE value that cannot be
// parsed as an integer results in a line of 0.
func Detect(cwd string, args []string, generateMode bool) ([]Invocation, error) {
	if generateMode {
		return generateModeInvocations(cwd)
	}

	// Best effort: a missing or malformed GOLINE is not an error, the line
	// number simply stays at zero.
	line, err := strconv.Atoi(os.Getenv("GOLINE"))
	if err != nil {
		line = 0
	}

	inv, err := NewInvocation(os.Getenv("GOFILE"), line, args)
	if err != nil {
		return nil, err
	}
	return []Invocation{inv}, nil
}
// Invocation records one requested run of counterfeiter: the arguments to
// run it with and the source position the request is associated with.
type Invocation struct {
	// Args is the argument list for the run. NewInvocation rejects an empty
	// list; invocations built from generate directives start with
	// "counterfeiter".
	Args []string
	// Line is the line number associated with the invocation. For directives
	// found by scanning a file it is the 1-based line of the directive.
	Line int
	// File is the name of the file associated with the invocation.
	File string
}
// NewInvocation returns an Invocation for the given file, line and args.
//
// It fails when args is empty; the error message is prefixed with the file
// and line so the offending location can be identified. The args slice is
// stored as given, not copied.
func NewInvocation(file string, line int, args []string) (Invocation, error) {
	if len(args) == 0 {
		return Invocation{}, fmt.Errorf("%s:%v an invocation of counterfeiter must have arguments", file, line)
	}
	return Invocation{Args: args, Line: line, File: file}, nil
}
// generateModeInvocations collects the generate directives found in the Go
// package located in cwd.
//
// The package is loaded with build.ImportDir; its regular, cgo, test and
// external-test Go files are visited in sorted name order, and the
// invocations found in each file are appended to the result in that order.
// The first error encountered aborts the scan.
func generateModeInvocations(cwd string) ([]Invocation, error) {
	pkg, err := build.ImportDir(cwd, build.IgnoreVendor)
	if err != nil {
		return nil, err
	}

	// Gather every kind of Go source file the package reports.
	groups := [][]string{pkg.GoFiles, pkg.CgoFiles, pkg.TestGoFiles, pkg.XTestGoFiles}
	total := 0
	for _, group := range groups {
		total += len(group)
	}
	names := make([]string, 0, total)
	for _, group := range groups {
		names = append(names, group...)
	}
	// Sorting makes the order of the returned invocations deterministic.
	sort.Strings(names)

	var found []Invocation
	for _, name := range names {
		fromFile, err := invocationsInFile(cwd, name)
		if err != nil {
			return nil, err
		}
		found = append(found, fromFile...)
	}
	return found, nil
}
// invocationsInFile reads the file named file inside dir and returns one
// Invocation for every line that matchForString recognises as a generate
// directive.
//
// The file is split on "\n"; each Invocation carries file (as passed in, not
// joined with dir) and the 1-based number of the line it was found on. A read
// failure or a NewInvocation failure is returned as-is.
func invocationsInFile(dir string, file string) ([]Invocation, error) {
	contents, err := os.ReadFile(filepath.Join(dir, file))
	if err != nil {
		return nil, err
	}

	var found []Invocation
	for idx, text := range strings.Split(string(contents), "\n") {
		if args, ok := matchForString(text); ok {
			// idx is zero-based; reported line numbers start at 1.
			inv, err := NewInvocation(file, idx+1, args)
			if err != nil {
				return nil, err
			}
			found = append(found, inv)
		}
	}
	return found, nil
}
// generateDirectivePrefix is the exact text, including the trailing space,
// that a source line must start with to be treated as a counterfeiter
// generate directive.
const generateDirectivePrefix = "//counterfeiter:generate "
// matchForString reports whether s is a generate directive, i.e. whether it
// begins with generateDirectivePrefix. When it does, the remainder of the
// line is converted to an argument list with stringToArgs and returned along
// with true; otherwise it returns nil and false.
func matchForString(s string) ([]string, bool) {
	if strings.HasPrefix(s, generateDirectivePrefix) {
		remainder := strings.TrimPrefix(s, generateDirectivePrefix)
		return stringToArgs(remainder), true
	}
	return nil, false
}
// stringToArgs turns s into an argument list: the literal "counterfeiter"
// followed by the whitespace-separated fields of s. An empty or all-blank s
// yields a list containing only "counterfeiter".
func stringToArgs(s string) []string {
	fields := strings.Fields(s)
	args := make([]string, 0, len(fields)+1)
	args = append(args, "counterfeiter")
	return append(args, fields...)
}
|