1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
|
package sftp
import (
"encoding"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
// _testSender is a packet sink used by the tests: every packet passed to
// sendPacket is forwarded on the sent channel, where the test can read it.
type _testSender struct {
	// sent carries each packet handed to sendPacket.
	sent chan encoding.BinaryMarshaler
}

// newTestSender returns a _testSender backed by an unbuffered channel, so a
// sendPacket call blocks until the packet has been received from sent.
func newTestSender() *_testSender {
	ch := make(chan encoding.BinaryMarshaler)
	return &_testSender{sent: ch}
}

// sendPacket forwards p on the sent channel. It always returns nil.
func (s _testSender) sendPacket(p encoding.BinaryMarshaler) error {
	s.sent <- p
	return nil
}
// fakepacket is a minimal stand-in packet for the tests. It carries only a
// request id and an order value.
type fakepacket struct {
	reqid uint32 // value reported by id()
	oid   uint32 // order value attached to the packet
}

// fake builds a fakepacket with request id rid and order value order.
func fake(rid, order uint32) fakepacket {
	var p fakepacket
	p.reqid = rid
	p.oid = order
	return p
}

// MarshalBinary returns a fixed four-byte, all-zero payload and a nil error;
// the packet's fields are not encoded.
func (fakepacket) MarshalBinary() ([]byte, error) {
	var zeros [4]byte
	return zeros[:], nil
}

// UnmarshalBinary ignores its input and always returns nil.
func (fakepacket) UnmarshalBinary(_ []byte) error {
	return nil
}

// id reports the packet's request id.
func (f fakepacket) id() uint32 {
	return f.reqid
}
// pair is one row of a test table: a request packet and the response packet
// listed alongside it.
type pair struct {
	in  fakepacket // request side of the row
	out fakepacket // response side of the row
}
// orderedPair is a table row converted to the ordered request/response
// wrappers that are passed to the packet manager in TestPacketManager.
type orderedPair struct {
	in  orderedRequest  // wrapped request
	out orderedResponse // wrapped response
}
// ttable1 is the baseline case: request ids and order values both count up
// from zero, and every response matches the request in the same row.
var ttable1 = []pair{
	{in: fake(0, 0), out: fake(0, 0)},
	{in: fake(1, 1), out: fake(1, 1)},
	{in: fake(2, 2), out: fake(2, 2)},
	{in: fake(3, 3), out: fake(3, 3)},
}
// ttable2 lists the outgoing packets out of order: requests 10..13 are in
// sequence, while the responses are listed as 12, 11, 13, 10.
var ttable2 = []pair{
	{in: fake(10, 0), out: fake(12, 2)},
	{in: fake(11, 1), out: fake(11, 1)},
	{in: fake(12, 2), out: fake(13, 3)},
	{in: fake(13, 3), out: fake(10, 0)},
}
// ttable3 uses request ids that are not incremental (7, 1, 9, 3); the last
// two responses are also listed in swapped order.
var ttable3 = []pair{
	{in: fake(7, 0), out: fake(7, 0)},
	{in: fake(1, 1), out: fake(1, 1)},
	{in: fake(9, 2), out: fake(3, 3)},
	{in: fake(3, 3), out: fake(9, 2)},
}
// ttable4 gives every packet the same request id (1), so only the order
// value tells them apart; the last two responses are listed in swapped order.
var ttable4 = []pair{
	{in: fake(1, 0), out: fake(1, 0)},
	{in: fake(1, 1), out: fake(1, 1)},
	{in: fake(1, 2), out: fake(1, 3)},
	{in: fake(1, 3), out: fake(1, 2)},
}
// tables collects every test table, in the order TestPacketManager runs them.
var tables = [][]pair{
	ttable1,
	ttable2,
	ttable3,
	ttable4,
}
// TestPacketManager runs every table in tables through a single packet
// manager. For each table it registers all requests, hands over all responses
// in the order the table lists them, and then checks that the k-th packet
// delivered to the sender carries the id of the k-th request.
func TestPacketManager(t *testing.T) {
	sender := newTestSender()
	s := newPktMgr(sender)

	for i, table := range tables {
		orderedPairs := make([]orderedPair, 0, len(table))
		for _, p := range table {
			orderedPairs = append(orderedPairs, orderedPair{
				in:  orderedRequest{p.in, p.in.oid},
				out: orderedResponse{p.out, p.out.oid},
			})
		}

		// All requests are registered before any response is handed over.
		for _, p := range orderedPairs {
			s.incomingPacket(p.in)
		}
		for _, p := range orderedPairs {
			s.readyPacket(p.out)
		}

		for j, p := range table {
			pkt := <-sender.sent
			// Fail this test with a message rather than panicking if the
			// sender receives something other than an orderedResponse.
			resp, ok := pkt.(orderedResponse)
			if !ok {
				t.Fatalf("table %d, packet %d: sender received %T, want orderedResponse", i, j, pkt)
			}
			// assert.Equal takes (expected, actual); the expected id is the
			// one of the request at the same position in the table.
			assert.Equal(t, p.in.id(), resp.id(), "table %d, packet %d", i, j)
		}
	}
	s.close()
}
// String formats the packet as "RmPkt:" followed by its ID (via %d).
func (p sshFxpRemovePacket) String() string {
	const layout = "RmPkt:%d"
	return fmt.Sprintf(layout, p.ID)
}
// String formats the packet as "OpPkt:" followed by its ID (via %d).
func (p sshFxpOpenPacket) String() string {
	const layout = "OpPkt:%d"
	return fmt.Sprintf(layout, p.ID)
}
// String formats the packet as "WrPkt:" followed by its ID (via %d).
func (p sshFxpWritePacket) String() string {
	const layout = "WrPkt:%d"
	return fmt.Sprintf(layout, p.ID)
}
// String formats the packet as "ClPkt:" followed by its ID (via %d).
func (p sshFxpClosePacket) String() string {
	const layout = "ClPkt:%d"
	return fmt.Sprintf(layout, p.ID)
}
|