1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
|
//go:build generate
package main
import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"log"
"os"
"strings"
"text/template"
"golang.org/x/tools/imports"
)
// main generates a "multiplexer" source file from a struct of function fields.
//
// Usage: <program> <struct_name> <input_file> <template_file> <output_file>
//
// It parses input_file, finds the top-level struct type named struct_name and,
// for every field of function type, collects the data described on
// multiplexedField. That data is passed to the text/template read from
// template_file as the map {"Fields": []multiplexedField, "StructName": string};
// the template may also call "join" (strings.Join). The rendered output is
// prefixed with a "Code generated ... DO NOT EDIT." header, formatted and
// import-fixed with imports.Process, and written to output_file with
// permission bits 0644. Every failure terminates the program via log.Fatalf.
func main() {
	if len(os.Args) != 5 {
		log.Fatalf("Usage: %s <struct_name> <input_file> <template_file> <output_file>", os.Args[0])
	}
	structName := os.Args[1]
	inputFile := os.Args[2]
	templateFile := os.Args[3]
	outputFile := os.Args[4]

	fset := token.NewFileSet()
	// Parse the input file containing the struct type.
	file, err := parser.ParseFile(fset, inputFile, nil, parser.AllErrors)
	if err != nil {
		log.Fatalf("Failed to parse file: %v", err)
	}

	// Prepare data for the template.
	fields := findStructFields(file, structName)
	fieldDataList := collectFuncFields(fset, fields)

	// Read the template from file.
	templateContent, err := os.ReadFile(templateFile)
	if err != nil {
		log.Fatalf("Failed to read template file: %v", err)
	}

	// Generate the code using the template.
	tmpl, err := template.New("multiplexer").Funcs(template.FuncMap{"join": strings.Join}).Parse(string(templateContent))
	if err != nil {
		log.Fatalf("Failed to parse template: %v", err)
	}
	var generatedCode bytes.Buffer
	generatedCode.WriteString("// Code generated by generate_multiplexer.go; DO NOT EDIT.\n\n")
	if err := tmpl.Execute(&generatedCode, map[string]interface{}{
		"Fields":     fieldDataList,
		"StructName": structName,
	}); err != nil {
		log.Fatalf("Failed to execute template: %v", err)
	}

	// Format the generated code and add imports. The output path is passed
	// because imports.Process uses the file's location when choosing imports.
	formattedCode, err := imports.Process(outputFile, generatedCode.Bytes(), nil)
	if err != nil {
		log.Fatalf("Failed to process imports: %v", err)
	}

	if err := os.WriteFile(outputFile, formattedCode, 0o644); err != nil {
		log.Fatalf("Failed to write output file: %v", err)
	}
}

// multiplexedField is the template data for one function-typed struct field.
type multiplexedField struct {
	Name        string // field name
	Params      string // parameter list, e.g. "ctx context.Context, ids ...int"
	Args        string // arguments that forward those parameters, e.g. "ctx, ids..."
	HasParams   bool   // whether Params is non-empty
	ReturnTypes string // "" (no results), " T" (one), or " (T1, T2)" (several); note the leading space
	HasReturn   bool   // whether the function declares any results
}

// findStructFields returns the field list of the top-level struct type named
// structName in file. It exits via log.Fatalf if no such type exists or if the
// type is not a struct.
func findStructFields(file *ast.File, structName string) []*ast.Field {
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.TYPE {
			continue
		}
		for _, spec := range genDecl.Specs {
			typeSpec, ok := spec.(*ast.TypeSpec)
			if !ok || typeSpec.Name.Name != structName {
				continue
			}
			structType, ok := typeSpec.Type.(*ast.StructType)
			if !ok {
				log.Fatalf("%s is not a struct", structName)
			}
			// Return as soon as the type is found. An empty struct has a nil
			// field list, so "nil fields" must not be treated as "not found".
			return structType.Fields.List
		}
	}
	log.Fatalf("Could not find %s type", structName)
	return nil // unreachable: log.Fatalf exits
}

// collectFuncFields builds template data for every function-typed field in
// fields, one entry per field name. Fields of any other type are skipped.
func collectFuncFields(fset *token.FileSet, fields []*ast.Field) []multiplexedField {
	var out []multiplexedField
	for _, field := range fields {
		funcType, ok := field.Type.(*ast.FuncType)
		if !ok {
			continue
		}
		for _, name := range field.Names {
			params, args := funcParams(fset, funcType, name.Name)
			returnTypes := funcResults(fset, funcType)
			out = append(out, multiplexedField{
				Name:        name.Name,
				Params:      strings.Join(params, ", "),
				Args:        strings.Join(args, ", "),
				HasParams:   len(params) > 0,
				ReturnTypes: returnTypes,
				HasReturn:   returnTypes != "",
			})
		}
	}
	return out
}

// funcParams renders each parameter of ft as "name Type" and returns the
// matching list of arguments that forward it. A variadic parameter is forwarded
// as "name..." so the slice is passed through rather than as a single element.
//
// We intentionally reject unnamed and "_" parameters (via log.Fatalf). We could
// auto-generate parameter names, but meaningful variable names are more helpful
// to the user. funcName is used only in those error messages.
func funcParams(fset *token.FileSet, ft *ast.FuncType, funcName string) (params, args []string) {
	if ft.Params == nil {
		return nil, nil
	}
	for i, param := range ft.Params.List {
		if len(param.Names) == 0 {
			log.Fatalf("encountered unnamed parameter at position %d in function %s", i, funcName)
		}
		paramType := nodeString(fset, param.Type)
		_, variadic := param.Type.(*ast.Ellipsis)
		for _, paramName := range param.Names {
			if paramName.Name == "_" {
				log.Fatalf("encountered underscore parameter at position %d in function %s", i, funcName)
			}
			params = append(params, fmt.Sprintf("%s %s", paramName.Name, paramType))
			arg := paramName.Name
			if variadic {
				arg += "..."
			}
			args = append(args, arg)
		}
	}
	return params, args
}

// funcResults renders the result types of ft as they appear after a signature:
// "" for none, " T" for exactly one, and " (T1, T2, ...)" otherwise. Result
// names are dropped, but every name still contributes one type, so
// "(n, m int)" renders as " (int, int)".
func funcResults(fset *token.FileSet, ft *ast.FuncType) string {
	if ft.Results == nil {
		return ""
	}
	var types []string
	for _, result := range ft.Results.List {
		resultType := nodeString(fset, result.Type)
		count := len(result.Names)
		if count == 0 {
			count = 1 // an unnamed result declares exactly one value
		}
		for j := 0; j < count; j++ {
			types = append(types, resultType)
		}
	}
	switch len(types) {
	case 0:
		return ""
	case 1:
		return " " + types[0]
	default:
		return fmt.Sprintf(" (%s)", strings.Join(types, ", "))
	}
}

// nodeString prints a type expression as Go source, exiting via log.Fatalf if
// printing fails.
func nodeString(fset *token.FileSet, expr ast.Expr) string {
	var buf bytes.Buffer
	if err := printer.Fprint(&buf, fset, expr); err != nil {
		log.Fatalf("Failed to print %T: %v", expr, err)
	}
	return buf.String()
}
|