File: copier.go

package info (click to toggle)
golang-golang-x-tools 1:0.25.0+ds-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental, forky, sid, trixie
  • size: 22,724 kB
  • sloc: javascript: 2,027; asm: 1,645; sh: 166; yacc: 155; makefile: 49; ansic: 8
file content (142 lines) | stat: -rw-r--r-- 3,918 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build ignore
// +build ignore

//go:generate go run ./copier.go

// Copier is a tool to automate copying of govulncheck's internal files.
//
//   - copy golang.org/x/vuln/internal/osv/ to osv
//   - copy golang.org/x/vuln/internal/govulncheck/ to govulncheck
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"go/parser"
	"go/token"
	"log"
	"os"
	"os/exec"
	"path/filepath"
	"strconv"
	"strings"

	"golang.org/x/tools/internal/edit"
)

// main downloads the latest golang.org/x/vuln module and copies its
// internal/osv and internal/govulncheck packages into same-named
// subdirectories of the current directory. Imports of
// golang.org/x/vuln/internal/... are rewritten to the import path of the
// current package (as reported by "go list .").
func main() {
	log.SetPrefix("copier: ")
	log.SetFlags(log.Lshortfile)

	const srcMod = "golang.org/x/vuln"
	srcDir, srcVer := downloadModule(srcMod + "@latest")

	cfg := rewrite{
		banner:        fmt.Sprintf("// Code generated by copying from %v@%v (go run copier.go); DO NOT EDIT.", srcMod, srcVer),
		srcImportPath: "golang.org/x/vuln/internal",
		dstImportPath: currentPackagePath(),
	}

	// Each package is copied from <module dir>/internal/<pkg> into ./<pkg>.
	for _, pkg := range []string{"osv", "govulncheck"} {
		copyFiles(pkg, filepath.Join(srcDir, "internal", pkg), cfg)
	}
}

// rewrite configures the transformation copyFiles applies to each copied file.
type rewrite struct {
	banner string // "DO NOT EDIT" marker inserted near the top of every copied file

	// Import paths under srcImportPath are rewritten to use dstImportPath
	// as their prefix instead.
	srcImportPath, dstImportPath string
}

// copyFiles copies every non-test Go source file from the directory src into
// the directory dst (created if needed), rewriting each file as described by
// cfg:
//
//   - cfg.banner is inserted after the leading copyright comment, or at the
//     very top of the file when the first comment is not a copyright notice;
//   - import paths equal to cfg.srcImportPath, or nested beneath it, are
//     rewritten to the same path under cfg.dstImportPath.
//
// Subdirectories and _test.go files are skipped. Files already present in
// dst with the same name are overwritten. Any error terminates the program.
func copyFiles(dst, src string, cfg rewrite) {
	entries, err := os.ReadDir(src)
	if err != nil {
		log.Fatalf("failed to read dir: %v", err)
	}
	if err := os.MkdirAll(dst, 0777); err != nil {
		log.Fatalf("failed to create dir: %v", err)
	}

	for _, e := range entries {
		fname := e.Name()
		// we need only non-test go files.
		if e.IsDir() || !strings.HasSuffix(fname, ".go") || strings.HasSuffix(fname, "_test.go") {
			continue
		}
		data, err := os.ReadFile(filepath.Join(src, fname))
		if err != nil {
			log.Fatal(err)
		}
		fset := token.NewFileSet()
		f, err := parser.ParseFile(fset, fname, data, parser.ParseComments|parser.ImportsOnly)
		if err != nil {
			log.Fatalf("parsing source module:\n%s", err)
		}

		buf := edit.NewBuffer(data)
		at := func(p token.Pos) int {
			return fset.File(p).Offset(p)
		}

		// Add the banner right after the copyright statement (the first
		// comment) when there is one, otherwise at the top of the file.
		// In both cases a blank line must separate the banner from the
		// adjacent text. Without the trailing "\n\n" in the top-of-file case,
		// the banner would run straight into the file's first line
		// (e.g. "...DO NOT EDIT.package osv"), producing invalid Go.
		bannerInsert, banner := f.FileStart, cfg.banner+"\n\n"
		if len(f.Comments) > 0 && strings.HasPrefix(f.Comments[0].Text(), "Copyright ") {
			bannerInsert = f.Comments[0].End()
			banner = "\n\n" + cfg.banner
		}
		buf.Replace(at(bannerInsert), at(bannerInsert), banner)

		// Adjust imports. Only rewrite srcImportPath itself or packages
		// beneath it. A bare prefix match would also catch unrelated paths
		// such as srcImportPath+"foo".
		for _, spec := range f.Imports {
			path, err := strconv.Unquote(spec.Path.Value)
			if err != nil {
				log.Fatal(err)
			}
			rest, ok := strings.CutPrefix(path, cfg.srcImportPath)
			if !ok || (rest != "" && !strings.HasPrefix(rest, "/")) {
				continue
			}
			newPath := cfg.dstImportPath + rest
			buf.Replace(at(spec.Path.Pos()), at(spec.Path.End()), strconv.Quote(newPath))
		}
		data = buf.Bytes()

		if err := os.WriteFile(filepath.Join(dst, fname), data, 0666); err != nil {
			log.Fatal(err)
		}
	}
}

// downloadModule runs "go mod download -json" for the module query
// srcModVers (for example "golang.org/x/vuln@latest"). It returns the Dir
// and Version fields of the JSON the go command prints. The program exits if
// the go command fails or its output cannot be decoded.
func downloadModule(srcModVers string) (dir, ver string) {
	args := []string{"mod", "download", "-json", srcModVers}
	var out, errOut bytes.Buffer
	cmd := exec.Command("go", args...)
	cmd.Stdout, cmd.Stderr = &out, &errOut
	if runErr := cmd.Run(); runErr != nil {
		log.Fatalf("go mod download -json %s: %v\n%s%s", srcModVers, runErr, errOut.Bytes(), out.Bytes())
	}

	var info struct {
		Dir     string
		Version string
	}
	if decErr := json.Unmarshal(out.Bytes(), &info); decErr != nil {
		log.Fatalf("go mod download -json %s: invalid JSON output: %v\n%s%s", srcModVers, decErr, errOut.Bytes(), out.Bytes())
	}
	dir, ver = info.Dir, info.Version
	return dir, ver
}

// currentPackagePath returns the output of "go list ." run in the current
// working directory, with surrounding whitespace trimmed. The program exits
// if the go command fails.
func currentPackagePath() string {
	var out, errOut bytes.Buffer
	list := exec.Command("go", "list", ".")
	list.Stdout = &out
	list.Stderr = &errOut

	err := list.Run()
	if err != nil {
		log.Fatalf("go list: %v\n%s%s", err, errOut.Bytes(), out.Bytes())
	}
	return strings.TrimSpace(out.String())
}