File: check_ast_test.go

package info (click to toggle)
elvish 0.21.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 6,372 kB
  • sloc: javascript: 236; sh: 130; python: 104; makefile: 88; xml: 9
file content (153 lines) | stat: -rw-r--r-- 4,416 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package parse

import (
	"fmt"
	"reflect"
	"strings"
	"unicode"
	"unicode/utf8"
)

// AST checking utilities. Used in test cases.

// ast is an AST specification. The name part identifies the type of the Node;
// for instance, "Chunk" specifies a Chunk. The fields part specifies the
// children to check; see the documentation of fs.
//
// When a Node contains exactly one child, it can be coalesced with its child
// by adding "/ChildName" to the name part. For instance, "Chunk/Pipeline"
// specifies a Chunk that contains exactly one Pipeline. In this case, the
// fields part specifies the children of the Pipeline instead of the Chunk
// (which has no additional interesting fields anyway). Multi-level coalescence
// like "Chunk/Pipeline/Form" is also allowed.
//
// The dynamic type of the Node being checked is assumed to be a pointer to a
// struct that embeds the "node" struct.
type ast struct {
	name   string // "/"-separated expected type names, outermost first
	fields fs     // specification for the fields of the innermost named Node
}

// fs specifies fields of a Node to check. For the value of field $f in the
// Node ("found value"), fs[$f] ("wanted value") is used to check against it.
// Exported fields that have no entry in fs are expected to hold the zero value
// of their type; unexported fields are not checked at all.
//
// If the key is "text", the SourceText of the Node is checked. It doesn't
// involve a found value.
//
// If the wanted value is nil, the found value is checked against nil.
//
// If the found value implements Node, then the wanted value must be either an
// ast, where the checking algorithm of ast applies, or a string, where the
// source text of the found value is checked.
//
// If the found value is a slice whose elements implement Node, then the wanted
// value must be a slice of the same length; each found element is then checked
// against the corresponding wanted element in the same way as a single Node.
//
// If the found value satisfies none of the above conditions, it is checked
// against the wanted value using reflect.DeepEqual.
type fs map[string]any

// checkAST checks an AST against a specification. It returns nil if n matches
// want, or an error describing the first mismatch found.
//
// Each "/"-separated component of want.name must match the type name of the
// node at the corresponding level. Every level except the last must have
// exactly one child, and that child is the node checked at the next level.
// The innermost node is then checked against want.fields as documented on fs:
//   - a field with an entry is checked with checkField;
//   - an exported field without an entry must hold the zero value of its type;
//   - the special key "text" is compared with the node's source text.
func checkAST(n Node, want ast) error {
	wantNames := strings.Split(want.name, "/")
	// Check coalesced levels.
	for i, wantName := range wantNames {
		// A nil Node (or a Node holding a nil pointer) cannot be reflected
		// into below; report it instead of panicking.
		if v := reflect.ValueOf(n); !v.IsValid() || (v.Kind() == reflect.Pointer && v.IsNil()) {
			return fmt.Errorf("want %s, got nil", wantName)
		}
		name := reflect.TypeOf(n).Elem().Name()
		if wantName != name {
			return fmt.Errorf("want %s, got %s (%s)", wantName, name, summary(n))
		}
		if i == len(wantNames)-1 {
			break
		}
		children := Children(n)
		if len(children) != 1 {
			return fmt.Errorf("want exactly 1 child, got %d (%s)", len(children), summary(n))
		}
		n = children[0]
	}

	// "text" is not a struct field: it checks the source text of the node
	// itself, as promised by the documentation of fs.
	if wantText, ok := want.fields["text"]; ok {
		s, isString := wantText.(string)
		if !isString {
			panic(fmt.Sprintf("bad want type %T for text (%s)", wantText, summary(n)))
		}
		if text := SourceText(n); s != text {
			return fmt.Errorf("want %q, got %q (text of: %s)", s, text, summary(n))
		}
	}

	ntype := reflect.TypeOf(n).Elem()
	nvalue := reflect.ValueOf(n).Elem()

	for i := 0; i < ntype.NumField(); i++ {
		fieldName := ntype.Field(i).Name
		if !exported(fieldName) {
			// Unexported fields (such as the embedded node struct) are not
			// part of the specification.
			continue
		}
		field := nvalue.Field(i)
		got := field.Interface()
		if wantField, ok := want.fields[fieldName]; ok {
			if err := checkField(got, wantField, "field "+fieldName+" of: "+summary(n)); err != nil {
				return err
			}
			continue
		}
		// Not specified: the field must hold the zero value of its type.
		// Value.IsZero and the field's static type also work for nil
		// interface-typed fields, where reflect.TypeOf(got) would be nil.
		if !field.IsZero() {
			zero := reflect.Zero(field.Type()).Interface()
			return fmt.Errorf("want %v, got %v (field %s of: %s)", zero, got, fieldName, summary(n))
		}
	}

	return nil
}

// checkField checks the value of a single Node field against its
// specification. ctx describes the field and is included in error messages.
// The rules are the ones documented on fs:
//
//   - a nil want requires got to be nil;
//   - a Node is checked with checkNodeInField (want is an ast or a string);
//   - a slice of Nodes is checked element-wise against a slice want;
//   - anything else is compared with reflect.DeepEqual.
//
// A want whose type does not fit the found value is a mistake in the test
// specification and causes a panic, like in checkNodeInField.
func checkField(got any, want any, ctx string) error {
	vgot := reflect.ValueOf(got)

	// Want nil.
	if want == nil {
		// got may be a nil interface (invalid Value) or a value of a kind that
		// can never be nil; Value.IsNil would panic on either.
		if !vgot.IsValid() {
			return nil
		}
		switch vgot.Kind() {
		case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
			reflect.Pointer, reflect.Slice, reflect.UnsafePointer:
			if vgot.IsNil() {
				return nil
			}
		}
		return fmt.Errorf("want nil, got %v (%s)", got, ctx)
	}

	// Want something, but got a nil interface value.
	if !vgot.IsValid() {
		return fmt.Errorf("want %v, got nil (%s)", want, ctx)
	}

	if node, ok := got.(Node); ok {
		// A nil pointer still satisfies Node but cannot be inspected.
		if vgot.Kind() == reflect.Pointer && vgot.IsNil() {
			return fmt.Errorf("want %v, got nil (%s)", want, ctx)
		}
		return checkNodeInField(node, want)
	}

	if tgot := vgot.Type(); tgot.Kind() == reflect.Slice && tgot.Elem().Implements(nodeType) {
		// Got a slice of Nodes; want must be a slice (or array) of specs.
		vwant := reflect.ValueOf(want)
		if k := vwant.Kind(); k != reflect.Slice && k != reflect.Array {
			panic(fmt.Sprintf("bad want type %T (%s)", want, ctx))
		}
		if vgot.Len() != vwant.Len() {
			return fmt.Errorf("want %d, got %d (%s)", vwant.Len(), vgot.Len(), ctx)
		}
		for i := 0; i < vgot.Len(); i++ {
			err := checkNodeInField(vgot.Index(i).Interface().(Node),
				vwant.Index(i).Interface())
			if err != nil {
				return err
			}
		}
		return nil
	}

	if !reflect.DeepEqual(want, got) {
		return fmt.Errorf("want %v, got %v (%s)", want, got, ctx)
	}
	return nil
}

// checkNodeInField checks a Node found in a field (or in a slice field)
// against its wanted value.
//   - A string want is compared with the Node's source text.
//   - An ast want is checked recursively with checkAST.
//   - Any other type of want is a mistake in the test specification and
//     causes a panic.
func checkNodeInField(got Node, want any) error {
	if wantText, isText := want.(string); isText {
		gotText := SourceText(got)
		if wantText == gotText {
			return nil
		}
		return fmt.Errorf("want %q, got %q (%s)", wantText, gotText, summary(got))
	}
	if spec, isSpec := want.(ast); isSpec {
		return checkAST(got, spec)
	}
	panic(fmt.Sprintf("bad want type %T (%s)", want, summary(got)))
}

// exported reports whether name begins with an upper-case letter, which is
// Go's rule for whether an identifier (here, a struct field) is exported.
// The empty string and names starting with invalid UTF-8 are not exported.
func exported(name string) bool {
	if name == "" {
		return false
	}
	first, _ := utf8.DecodeRuneInString(name)
	return unicode.IsUpper(first)
}