1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
|
// Package prehook implements a connection prehook mechanism, to handle any
// pre-negotiation required by a remote node before the B2F protocol can
// commence (e.g. packet node traversal).
package prehook
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"time"
"github.com/la5nta/pat/internal/debug"
"golang.org/x/sync/errgroup"
)
// ErrConnNotWrapped is returned by Script.Execute when the given connection
// is not a *Conn (i.e. it was not wrapped with Wrap), so no prehook can be
// executed on it.
var ErrConnNotWrapped = errors.New("connection not wrapped for prehook")
// Script describes a prehook program to be executed on a wrapped connection
// (see Script.Execute and Conn.Execute).
type Script struct {
	// File is the program to execute. It is passed as the command name to
	// exec.CommandContext.
	File string

	// Args are the arguments passed to File, not including the program
	// name itself.
	Args []string

	// Env is assigned to exec.Cmd.Env as-is: entries have the form
	// "key=value", and a nil Env means the process inherits the current
	// process's environment.
	Env []string
}
// Execute runs the script s on conn, which must be a *Conn obtained from
// Wrap.
//
// When conn is not wrapped the script is not started and ErrConnNotWrapped
// is returned; otherwise the result of (*Conn).Execute is returned.
func (s Script) Execute(ctx context.Context, conn net.Conn) error {
	wrapped, ok := conn.(*Conn)
	if !ok {
		return ErrConnNotWrapped
	}
	return wrapped.Execute(ctx, s)
}
// Conn is a net.Conn wrapper, created by Wrap, on which a prehook script
// can be executed.
//
// Reads go through br, a buffered reader over the embedded connection, so
// bytes that were peeked while a prehook was running are not lost. Writes
// are not overridden and go straight to the embedded net.Conn.
type Conn struct {
	net.Conn
	br *bufio.Reader // buffered reader over the embedded Conn; see Wrap and Read
}
// Verify checks that file resolves to an executable program using the
// lookup rules of exec.LookPath, returning nil if it does and the lookup
// error otherwise.
//
// A program that is only found through a relative PATH entry (reported by
// exec.LookPath as exec.ErrDot) is accepted as well.
func Verify(file string) error {
	if _, err := exec.LookPath(file); err != nil && !errors.Is(err, exec.ErrDot) {
		return err
	}
	return nil
}
// Wrap wraps conn so that a prehook script can be executed on it.
//
// The returned *Conn implements net.Conn and must be used in place of conn
// for the rest of the connection's lifetime, including after the prehook has
// run: its reads go through an internal bufio.Reader, so reading conn
// directly could skip data that is already buffered.
func Wrap(conn net.Conn) *Conn {
	wrapped := new(Conn)
	wrapped.Conn = conn
	wrapped.br = bufio.NewReader(conn)
	return wrapped
}
// Read reads from the internal buffered reader instead of the embedded
// connection, so bytes that were buffered while a prehook ran are returned
// before any newly received data.
func (p *Conn) Read(b []byte) (int, error) {
	return p.br.Read(b)
}
// Execute executes the prehook script on the connection, returning nil if
// the process terminated successfully (exit code 0) and forwarding stopped
// cleanly.
//
// The script's stdout is written to the connection and its stderr to
// os.Stderr. Data received on the connection is forwarded line by line to
// the script's stdin (see forwardLines). Once the process has exited,
// forwarding stops at the next line boundary without consuming the start of
// the next line, which stays buffered in p for subsequent Reads.
//
// NOTE: forwarding blocks while waiting for the next line, so after the
// process exits Execute only returns once the remote sends more data or the
// connection read fails.
//
// If forwarding fails (e.g. the connection is lost), the script's stdin is
// closed and the process is killed, so Execute cannot hang on a script that
// would otherwise wait forever for input. Canceling ctx also kills the
// process. The first error from either the process or the forwarding is
// returned.
func (p *Conn) Execute(ctx context.Context, script Script) error {
	// Release every context derived below when we return, including on the
	// early-return paths before the goroutines are started.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// The errgroup context is canceled as soon as either goroutine fails.
	// Binding the process to it kills the script when forwarding fails
	// instead of leaving it running unattended.
	g, ctx := errgroup.WithContext(ctx)

	cmd := exec.CommandContext(ctx, script.File, script.Args...)
	// Verify accepts programs resolved through a relative PATH entry
	// (exec.ErrDot); honor that here instead of failing later in Start.
	if errors.Is(cmd.Err, exec.ErrDot) {
		cmd.Err = nil
	}
	cmd.Env = script.Env
	cmd.Stderr = os.Stderr
	cmd.Stdout = p.Conn
	cmdStdin, err := cmd.StdinPipe()
	if err != nil {
		return err
	}

	debugf("start cmd: %s", cmd)
	if err := cmd.Start(); err != nil {
		return err
	}

	// exited is canceled once the process has terminated, telling
	// forwardLines to stop at the next line boundary.
	exited, markExited := context.WithCancel(ctx)
	defer markExited()

	g.Go(func() error {
		// Close stdin when forwarding stops so a script blocked reading it
		// sees EOF instead of hanging (which would also block cmd.Wait).
		// os/exec explicitly allows closing the pipe early; Wait closing it
		// again afterwards is harmless.
		defer cmdStdin.Close()
		return forwardLines(exited, cmdStdin, p.br)
	})
	g.Go(func() error {
		defer markExited()
		return cmd.Wait()
	})
	return g.Wait()
}
// forwardLines forwards data from r to the spawned process (w) line by line.
//
// The line delimiter is CR or LF, but to facilitate scripting each line is
// forwarded with an LF ending only: CR is rewritten to LF, and empty lines
// (such as the LF of a CRLF pair) are discarded. Everything forwarded is also
// echoed to os.Stdout so the user can see what's going on.
//
// Data is written to w one complete line at a time, so a partial line (e.g. a
// prompt with no line ending) only reaches the process once its delimiter
// arrives.
//
// ctx is expected to be canceled once the process has exited. It is checked
// at line boundaries. After forwarding a line, forwardLines waits up to 100ms
// for ctx to be done. Before starting the next line, it peeks at the next byte
// without consuming it. If ctx is done at that point, forwardLines returns nil
// and leaves the byte buffered in r for whoever reads the connection next.
// The peek blocks until data arrives, so ctx is only observed there once the
// remote sends something.
//
// A read error from r is returned as is. A failed write to w is wrapped and
// reported as the child process having exited prematurely.
func forwardLines(ctx context.Context, w io.Writer, r *bufio.Reader) error {
	// Copy the lines to stdout so the user can see what's going on.
	stdinBuffered := bufio.NewWriter(io.MultiWriter(w, os.Stdout))
	// Push out any partial line on return. Best effort: its error is
	// deliberately ignored since we are already returning.
	defer stdinBuffered.Flush()

	isDelimiter := func(b byte) bool { return b == '\n' || b == '\r' }

	var isPrefix bool // true if we're in the middle of a line
	for {
		if !isPrefix {
			// Peek until the next new line (discard empty lines).
			debugf("wait next line")
			switch peek, err := r.Peek(1); {
			case err != nil:
				// Connection lost.
				debugf("connection lost while waiting for next line")
				return err
			case len(peek) > 0 && isDelimiter(peek[0]):
				debugf("discard %q", peek)
				// Cannot fail: the byte was just successfully peeked.
				_, _ = r.Discard(1)
				continue
			case ctx.Err() != nil:
				// Child process exited before the next line
				// arrived. We're done; the peeked byte stays in r.
				debugf("cmd exited while waiting for next line")
				return nil
			default:
				debugf("at next line")
			}
		}

		// Read and forward the byte.
		// Replace CR with LF for convenience.
		b, err := r.ReadByte()
		if err != nil {
			// Connection lost.
			debugf("connection lost while reading next byte")
			return err
		}
		if b == '\r' {
			b = '\n'
		}
		// WriteByte flushes by itself when the buffer is full (very long
		// lines). Stop as soon as that fails rather than draining the rest
		// of the line from the connection into a dead pipe.
		if err := stdinBuffered.WriteByte(b); err != nil {
			return fmt.Errorf("child process exited prematurely: %w", err)
		}
		isPrefix = !isDelimiter(b)
		if isPrefix {
			// Keep going. We're in the middle of a line.
			continue
		}

		// A line was just terminated.
		// Flush and wait a bit to check if the process exits.
		if err := stdinBuffered.Flush(); err != nil {
			return fmt.Errorf("child process exited prematurely: %w", err)
		}
		select {
		case <-time.After(100 * time.Millisecond):
			// Child process is still alive. Keep going.
		case <-ctx.Done():
			// Child process exited. We're done.
			return nil
		}
	}
}
// debugf emits a debug message through debug.Printf, tagging it with the
// package's "prehook: " prefix.
func debugf(format string, args ...interface{}) {
	const prefix = "prehook: "
	debug.Printf(prefix+format, args...)
}
|