1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
|
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package test
import (
"sort"
"golang.org/x/vuln/internal/govulncheck"
"golang.org/x/vuln/internal/osv"
)
// MockHandler implements govulncheck.Handler by recording every message it
// receives, in arrival order, in the exported slice for that message kind.
// The recorded messages can later be replayed to another handler with Write.
//
// For use in tests.
type MockHandler struct {
	// ConfigMessages holds the argument of each call to Config.
	ConfigMessages []*govulncheck.Config
	// ProgressMessages holds the argument of each call to Progress.
	ProgressMessages []*govulncheck.Progress
	// OSVMessages holds the argument of each call to OSV.
	OSVMessages []*osv.Entry
	// FindingMessages holds the argument of each call to Finding.
	// Sort (and therefore Write) reorders this slice in place.
	FindingMessages []*govulncheck.Finding
}
// NewMockHandler returns a MockHandler that has not recorded any messages yet.
func NewMockHandler() *MockHandler {
	h := new(MockHandler)
	return h
}
// Config appends config to h.ConfigMessages. It never returns an error.
func (h *MockHandler) Config(config *govulncheck.Config) error {
	recorded := h.ConfigMessages
	recorded = append(recorded, config)
	h.ConfigMessages = recorded
	return nil
}
// Progress appends progress to h.ProgressMessages. It never returns an error.
func (h *MockHandler) Progress(progress *govulncheck.Progress) error {
	recorded := h.ProgressMessages
	recorded = append(recorded, progress)
	h.ProgressMessages = recorded
	return nil
}
// OSV appends entry to h.OSVMessages. It never returns an error.
func (h *MockHandler) OSV(entry *osv.Entry) error {
	recorded := h.OSVMessages
	recorded = append(recorded, entry)
	h.OSVMessages = recorded
	return nil
}
// Finding appends finding to h.FindingMessages. It never returns an error.
func (h *MockHandler) Finding(finding *govulncheck.Finding) error {
	recorded := h.FindingMessages
	recorded = append(recorded, finding)
	h.FindingMessages = recorded
	return nil
}
// Sort reorders h.FindingMessages in place.
//
// Findings are ordered first by OSV ID in descending order. Ties are broken
// by the Module, Package and Function of the first frame of each finding's
// Trace, each compared in ascending order. For the same OSV ID, a finding
// with an empty Trace sorts before one with a non-empty Trace, instead of
// causing an index-out-of-range panic. The sort is stable, so findings that
// compare equal keep the order in which they were recorded.
func (h *MockHandler) Sort() {
	sort.SliceStable(h.FindingMessages, func(i, j int) bool {
		fi, fj := h.FindingMessages[i], h.FindingMessages[j]
		// NOTE(review): OSV IDs are compared in descending order, unlike the
		// frame fields below. Kept as-is; expected test outputs presumably
		// depend on this order.
		if fi.OSV != fj.OSV {
			return fi.OSV > fj.OSV
		}
		// Only the first frame is compared; an empty trace has no frame.
		if len(fi.Trace) == 0 || len(fj.Trace) == 0 {
			return len(fi.Trace) == 0 && len(fj.Trace) != 0
		}
		iframe, jframe := fi.Trace[0], fj.Trace[0]
		switch {
		case iframe.Module != jframe.Module:
			return iframe.Module < jframe.Module
		case iframe.Package != jframe.Package:
			return iframe.Package < jframe.Package
		default:
			return iframe.Function < jframe.Function
		}
	})
}
// Write sorts the recorded findings with Sort and then replays every recorded
// message to the handler to.
//
// Config messages are written first, then Progress messages, each in the order
// they were recorded. Next, each finding is written in sorted order. The first
// time a finding's OSV ID is encountered, every recorded OSV entry with that ID
// is written just before the finding. Finally, the OSV entries whose IDs no
// finding referenced are written in the order they were recorded.
//
// Write stops at, and returns, the first error reported by to.
func (h *MockHandler) Write(to govulncheck.Handler) error {
	h.Sort()
	for _, c := range h.ConfigMessages {
		if err := to.Config(c); err != nil {
			return err
		}
	}
	for _, p := range h.ProgressMessages {
		if err := to.Progress(p); err != nil {
			return err
		}
	}

	// emitOSV writes every recorded OSV entry with the given ID, in the order
	// the entries were recorded.
	emitOSV := func(id string) error {
		for _, entry := range h.OSVMessages {
			if entry.ID != id {
				continue
			}
			if err := to.OSV(entry); err != nil {
				return err
			}
		}
		return nil
	}

	written := map[string]bool{}
	for _, f := range h.FindingMessages {
		if !written[f.OSV] {
			written[f.OSV] = true
			if err := emitOSV(f.OSV); err != nil {
				return err
			}
		}
		if err := to.Finding(f); err != nil {
			return err
		}
	}

	// Entries that no finding referred to are still replayed, after the findings.
	for _, entry := range h.OSVMessages {
		if written[entry.ID] {
			continue
		}
		if err := to.OSV(entry); err != nil {
			return err
		}
	}
	return nil
}
|