File: generate_multiplexer.go

package info (click to toggle)
golang-github-lucas-clemente-quic-go 0.54.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 4,312 kB
  • sloc: sh: 54; makefile: 7
file content (161 lines) | stat: -rw-r--r-- 4,315 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
//go:build generate

package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/parser"
	"go/printer"
	"go/token"
	"log"
	"os"
	"strings"
	"text/template"

	"golang.org/x/tools/imports"
)

// main generates a multiplexer source file for a struct whose fields are
// function values.
//
// Usage: <prog> <struct_name> <input_file> <template_file> <output_file>
//
// It parses <input_file>, finds the type declaration named <struct_name>
// (which must be a struct) and collects every function-typed field. For each
// field name it derives, as Go source text:
//   - Params:      the parameter list ("a int, b string"),
//   - Args:        the call-site argument list ("a, b"; variadic parameters
//     are forwarded as "name..."),
//   - ReturnTypes: the result types, prefixed with a space and parenthesised
//     when there is more than one.
//
// The text/template in <template_file> is then executed with
// {"Fields": []FieldData, "StructName": string}. The result is passed through
// goimports and written to <output_file>. Any failure aborts via log.Fatalf.
func main() {
	if len(os.Args) != 5 {
		log.Fatalf("Usage: %s <struct_name> <input_file> <template_file> <output_file>", os.Args[0])
	}

	structName := os.Args[1]
	inputFile := os.Args[2]
	templateFile := os.Args[3]
	outputFile := os.Args[4]

	fset := token.NewFileSet()

	// exprString renders an AST expression (here always a type) back to Go
	// source text, aborting instead of silently producing an empty string.
	exprString := func(expr ast.Expr) string {
		var buf bytes.Buffer
		if err := printer.Fprint(&buf, fset, expr); err != nil {
			log.Fatalf("Failed to print expression: %v", err)
		}
		return buf.String()
	}

	// Parse the input file containing the struct type
	file, err := parser.ParseFile(fset, inputFile, nil, parser.AllErrors)
	if err != nil {
		log.Fatalf("Failed to parse file: %v", err)
	}

	// Find the specified struct type in the AST. A separate flag is needed
	// because an empty struct legitimately has an empty (nil) field list.
	var (
		fields []*ast.Field
		found  bool
	)
search:
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.TYPE {
			continue
		}
		for _, spec := range genDecl.Specs {
			typeSpec, ok := spec.(*ast.TypeSpec)
			if !ok || typeSpec.Name.Name != structName {
				continue
			}
			structType, ok := typeSpec.Type.(*ast.StructType)
			if !ok {
				log.Fatalf("%s is not a struct", structName)
			}
			fields = structType.Fields.List
			found = true
			break search
		}
	}

	if !found {
		log.Fatalf("Could not find %s type", structName)
	}

	// Prepare data for the template
	type FieldData struct {
		Name        string
		Params      string
		Args        string
		HasParams   bool
		ReturnTypes string
		HasReturn   bool
	}

	var fieldDataList []FieldData

	for _, field := range fields {
		funcType, ok := field.Type.(*ast.FuncType)
		if !ok {
			// Only function-typed fields are multiplexed.
			continue
		}
		for _, name := range field.Names {
			fieldData := FieldData{Name: name.Name}

			// extract parameters
			var params []string
			var args []string
			if funcType.Params != nil {
				for i, param := range funcType.Params.List {
					// We intentionally reject unnamed (and, further down, "_") function parameters.
					// We could auto-generate parameter names,
					// but having meaningful variable names will be more helpful for the user.
					if len(param.Names) == 0 {
						log.Fatalf("encountered unnamed parameter at position %d in function %s", i, fieldData.Name)
					}
					paramType := exprString(param.Type)
					// A variadic parameter ("xs ...T") must be forwarded as "xs...";
					// passing plain "xs" would hand the slice over as a single argument.
					_, variadic := param.Type.(*ast.Ellipsis)
					for _, paramName := range param.Names {
						if paramName.Name == "_" {
							log.Fatalf("encountered underscore parameter at position %d in function %s", i, fieldData.Name)
						}
						params = append(params, fmt.Sprintf("%s %s", paramName.Name, paramType))
						arg := paramName.Name
						if variadic {
							arg += "..."
						}
						args = append(args, arg)
					}
				}
			}
			fieldData.Params = strings.Join(params, ", ")
			fieldData.Args = strings.Join(args, ", ")
			fieldData.HasParams = len(params) > 0

			// extract return types
			if funcType.Results != nil && len(funcType.Results.List) > 0 {
				fieldData.HasReturn = true
				var returns []string
				for _, result := range funcType.Results.List {
					resultType := exprString(result.Type)
					// A grouped named result such as "(a, b int)" is a single
					// *ast.Field with several names; emit its type once per name.
					// Unnamed results have no names but still count once.
					count := len(result.Names)
					if count == 0 {
						count = 1
					}
					for j := 0; j < count; j++ {
						returns = append(returns, resultType)
					}
				}
				if len(returns) == 1 {
					fieldData.ReturnTypes = fmt.Sprintf(" %s", returns[0])
				} else {
					fieldData.ReturnTypes = fmt.Sprintf(" (%s)", strings.Join(returns, ", "))
				}
			}

			fieldDataList = append(fieldDataList, fieldData)
		}
	}

	// Read the template from file
	templateContent, err := os.ReadFile(templateFile)
	if err != nil {
		log.Fatalf("Failed to read template file: %v", err)
	}

	// Generate the code using the template
	tmpl, err := template.New("multiplexer").Funcs(template.FuncMap{"join": strings.Join}).Parse(string(templateContent))
	if err != nil {
		log.Fatalf("Failed to parse template: %v", err)
	}

	var generatedCode bytes.Buffer
	generatedCode.WriteString("// Code generated by generate_multiplexer.go; DO NOT EDIT.\n\n")
	if err = tmpl.Execute(&generatedCode, map[string]interface{}{
		"Fields":     fieldDataList,
		"StructName": structName,
	}); err != nil {
		log.Fatalf("Failed to execute template: %v", err)
	}

	// Format the generated code and add imports
	formattedCode, err := imports.Process(outputFile, generatedCode.Bytes(), nil)
	if err != nil {
		log.Fatalf("Failed to process imports: %v", err)
	}

	if err := os.WriteFile(outputFile, formattedCode, 0o644); err != nil {
		log.Fatalf("Failed to write output file: %v", err)
	}
}