File: assert.go

package info (click to toggle)
golang-github-bruth-assert 0.0+git20130823.de420fa-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 72 kB
  • sloc: makefile: 2
file content (129 lines) | stat: -rw-r--r-- 3,562 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package assert

import (
    "fmt"
    "github.com/kr/pretty"
    "reflect"
    "runtime"
    "strings"
    "testing"
)

// errorPrefix is written at the start of every detail line that the
// assertion helpers in this file report beneath the "file:line" header.
var errorPrefix = "! "

// -- Assertion handlers

// assert is the shared failure path used by the assertion helpers.
//
// When success is true it does nothing. Otherwise it reports the file and
// line located callDepth frames above its own caller, invokes f so the
// specific assertion can add its detail lines, and stops the test with
// FailNow.
func assert(t *testing.T, success bool, f func(), callDepth int) {
    if success {
        return
    }
    // One extra frame is skipped to step over assert itself.
    skip := callDepth + 1
    _, path, lineNo, _ := runtime.Caller(skip)
    t.Errorf("%s:%d", path, lineNo)
    f()
    t.FailNow()
}

func equal(t *testing.T, expected, got interface{}, callDepth int, messages ...interface{}) {
    fn := func() {
        for _, desc := range pretty.Diff(expected, got) {
            t.Error(errorPrefix, desc)
        }
        if len(messages) > 0 {
            t.Error(errorPrefix, "-", fmt.Sprint(messages...))
        }
    }
    assert(t, isEqual(expected, got), fn, callDepth+1)
}

func notEqual(t *testing.T, expected, got interface{}, callDepth int, messages ...interface{}) {
    fn := func() {
        t.Errorf("%s Unexpected: %#v", errorPrefix, got)
        if len(messages) > 0 {
            t.Error(errorPrefix, "-", fmt.Sprint(messages...))
        }
    }
    assert(t, !isEqual(expected, got), fn, callDepth+1)
}

// contains fails the test unless got includes expected as a substring.
//
// On failure it reports the substring that was looked for and the string that
// was searched, both in Go syntax, followed by the optional messages joined
// by fmt.Sprint. callDepth is forwarded to assert incremented by one to
// account for this function's own frame.
func contains(t *testing.T, expected, got string, callDepth int, messages ...interface{}) {
    found := strings.Contains(got, expected)
    report := func() {
        t.Errorf("%s Expected to find: %#v", errorPrefix, expected)
        t.Errorf("%s in: %#v", errorPrefix, got)
        if len(messages) == 0 {
            return
        }
        t.Error(errorPrefix, "-", fmt.Sprint(messages...))
    }
    assert(t, found, report, callDepth+1)
}

// notContains fails the test when got includes unexpected as a substring.
//
// On failure it reports the substring that should have been absent and the
// string that was searched, both in Go syntax, followed by the optional
// messages joined by fmt.Sprint. callDepth is forwarded to assert incremented
// by one to account for this function's own frame.
func notContains(t *testing.T, unexpected, got string, callDepth int, messages ...interface{}) {
    absent := !strings.Contains(got, unexpected)
    report := func() {
        t.Errorf("%s Expected not to find: %#v", errorPrefix, unexpected)
        t.Errorf("%s in: %#v", errorPrefix, got)
        if len(messages) == 0 {
            return
        }
        t.Error(errorPrefix, "-", fmt.Sprint(messages...))
    }
    assert(t, absent, report, callDepth+1)
}

// -- Matching

// isEqual reports whether got matches expected.
//
// A nil expected (the nil interface value) matches whatever isNil accepts for
// got; every other expected value is compared with reflect.DeepEqual.
func isEqual(expected, got interface{}) bool {
    if expected != nil {
        return reflect.DeepEqual(expected, got)
    }
    return isNil(got)
}

// isNil reports whether got is nil.
//
// It returns true for the nil interface value and for a typed nil whose kind
// can hold nil: chan, func, interface, map, pointer, slice or unsafe.Pointer.
// Values of any other kind are never nil.
func isNil(got interface{}) bool {
    if got == nil {
        return true
    }
    value := reflect.ValueOf(got)
    switch value.Kind() {
    case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
        return value.IsNil()
    case reflect.UnsafePointer:
        // Value.IsNil's documented contract does not list unsafe.Pointer,
        // whereas Value.Pointer's does, so compare the raw address instead.
        return value.Pointer() == 0
    }
    return false
}

// -- Public API

// Equal stops the test with a failure unless got matches expected. A nil
// expected matches any got that isNil accepts; otherwise the two values are
// compared with reflect.DeepEqual. Optional messages are appended to the
// failure output.
func Equal(t *testing.T, expected, got interface{}, messages ...interface{}) {
    // Only this wrapper's frame sits between equal and the calling test.
    const wrapperFrames = 1
    equal(t, expected, got, wrapperFrames, messages...)
}

// NotEqual stops the test with a failure when got matches expected, using the
// same comparison as Equal. Optional messages are appended to the failure
// output.
func NotEqual(t *testing.T, expected, got interface{}, messages ...interface{}) {
    // Only this wrapper's frame sits between notEqual and the calling test.
    const wrapperFrames = 1
    notEqual(t, expected, got, wrapperFrames, messages...)
}

// True stops the test with a failure unless got equals the bool value true.
// Optional messages are appended to the failure output.
func True(t *testing.T, got interface{}, messages ...interface{}) {
    // Only this wrapper's frame sits between equal and the calling test.
    const wrapperFrames = 1
    equal(t, true, got, wrapperFrames, messages...)
}

// False stops the test with a failure unless got equals the bool value false.
// Optional messages are appended to the failure output.
func False(t *testing.T, got interface{}, messages ...interface{}) {
    // Only this wrapper's frame sits between equal and the calling test.
    const wrapperFrames = 1
    equal(t, false, got, wrapperFrames, messages...)
}

// Nil stops the test with a failure unless got is nil: either the nil
// interface value or a typed nil that isNil recognises (for example a nil
// pointer, map or slice). Optional messages are appended to the failure
// output.
func Nil(t *testing.T, got interface{}, messages ...interface{}) {
    // Only this wrapper's frame sits between equal and the calling test.
    const wrapperFrames = 1
    equal(t, nil, got, wrapperFrames, messages...)
}

// NotNil stops the test with a failure when got is nil: either the nil
// interface value or a typed nil that isNil recognises (for example a nil
// pointer, map or slice). Optional messages are appended to the failure
// output.
func NotNil(t *testing.T, got interface{}, messages ...interface{}) {
    // Only this wrapper's frame sits between notEqual and the calling test.
    const wrapperFrames = 1
    notEqual(t, nil, got, wrapperFrames, messages...)
}

// Contains stops the test with a failure unless expected occurs as a
// substring of got. Optional messages are appended to the failure output.
func Contains(t *testing.T, expected, got string, messages ...interface{}) {
    // Only this wrapper's frame sits between contains and the calling test.
    const wrapperFrames = 1
    contains(t, expected, got, wrapperFrames, messages...)
}

// NotContains stops the test with a failure when unexpected occurs as a
// substring of got. Optional messages are appended to the failure output.
func NotContains(t *testing.T, unexpected, got string, messages ...interface{}) {
    // Only this wrapper's frame sits between notContains and the calling test.
    const wrapperFrames = 1
    notContains(t, unexpected, got, wrapperFrames, messages...)
}

// Panic runs fn and stops the test with a failure unless the value fn panics
// with matches err, using the same comparison as Equal.
//
// When fn returns without panicking the recovered value is nil, so a nil err
// asserts that fn does not panic. Optional messages are appended to the
// failure output.
//
// If fn terminates the goroutine instead of returning or panicking (for
// example via t.FailNow), no comparison is made.
func Panic(t *testing.T, err interface{}, fn func(), messages ...interface{}) {
    var recovered interface{}
    func() {
        // recover only has an effect when called directly by the deferred
        // function, so capture its result here and compare later.
        defer func() {
            recovered = recover()
        }()
        fn()
    }()

    // Compare from Panic's own frame rather than from inside the deferred
    // closure. The number of frames between a deferred closure and the test
    // differs between the panicking and non-panicking paths, so a fixed call
    // depth cannot locate the test reliably from there. From here only this
    // wrapper sits between equal and the calling test, as in the other
    // public assertions.
    equal(t, err, recovered, 1, messages...)
}