File: prehook.go

package info (click to toggle)
pat 0.19.2-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,228 kB
  • sloc: javascript: 3,864; sh: 147; makefile: 11
file content (163 lines) | stat: -rw-r--r-- 4,183 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
// Package prehook implements a connection prehook mechanism, to handle any
// pre-negotiation required by a remote node before the B2F protocol can
// commence (e.g. packet node traversal).
package prehook

import (
	"bufio"
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"os/exec"
	"time"

	"github.com/la5nta/pat/internal/debug"
	"golang.org/x/sync/errgroup"
)

var (
	// ErrConnNotWrapped is returned by Script.Execute when the connection it
	// is given is not a *Conn (i.e. it was not wrapped with Wrap).
	ErrConnNotWrapped = errors.New("connection not wrapped for prehook")
)

// Script describes a prehook program to run on a wrapped connection.
type Script struct {
	// File is the program to run. It is used as the command name for
	// exec.CommandContext, so a name without path separators is looked up
	// in PATH.
	File string

	// Args are the command-line arguments passed to File.
	Args []string

	// Env is the environment of the child process, in "KEY=value" form.
	// As with exec.Cmd.Env, a nil Env makes the child inherit the current
	// process's environment.
	Env []string
}

// Execute runs the prehook script s on conn, which must be a connection
// returned by Wrap.
//
// If conn is not a *Conn, nothing is executed and ErrConnNotWrapped is
// returned. Otherwise the result of (*Conn).Execute is returned.
func (s Script) Execute(ctx context.Context, conn net.Conn) error {
	wrapped, ok := conn.(*Conn)
	if !ok {
		return ErrConnNotWrapped
	}
	return wrapped.Execute(ctx, s)
}

// Conn is a net.Conn wrapper, created by Wrap, that can run a prehook
// script on the connection before it is used for anything else.
//
// Reads are served from an internal buffered reader that the prehook also
// consumes from. Any data buffered while the prehook ran is therefore still
// delivered by Read afterwards.
type Conn struct {
	// Conn is the wrapped connection. Methods not overridden by Conn,
	// such as Write, are promoted from it.
	net.Conn

	// br buffers reads from the wrapped connection and is shared between
	// Read and the prehook's line forwarding.
	br *bufio.Reader
}

// Verify checks that file can be resolved as an executable, returning nil if
// the script file is found and valid and the lookup error otherwise.
//
// A file that only resolves relative to the current directory (reported by
// exec.LookPath as exec.ErrDot) is deliberately accepted.
func Verify(file string) error {
	if _, err := exec.LookPath(file); err != nil && !errors.Is(err, exec.ErrDot) {
		return err
	}
	return nil
}

// Wrap wraps conn so that a prehook script can be executed on it.
//
// The returned Conn implements net.Conn. All its reads go through an internal
// buffered reader over conn. Once the prehook script has been executed, use
// the returned Conn instead of the original for the rest of the connection's
// lifetime. Otherwise, data already buffered by the wrapper could be missed.
func Wrap(conn net.Conn) *Conn {
	wrapped := new(Conn)
	wrapped.Conn = conn
	wrapped.br = bufio.NewReader(conn)
	return wrapped
}

// Read implements io.Reader by reading through the wrapper's buffered
// reader. Any data it buffered while the prehook was running is therefore
// returned before new data from the underlying connection.
func (p *Conn) Read(buf []byte) (int, error) {
	return p.br.Read(buf)
}

// Execute executes the prehook script, returning nil if the process
// terminated successfully (exit code 0).
//
// While the script runs, its stdout is written directly to the underlying
// connection. Data received from the connection is forwarded to the script's
// stdin line by line (see forwardLines). Execute returns once the script has
// exited and line forwarding has stopped. On success, forwarding stops at the
// next line boundary, so any data after it is left for subsequent reads.
//
// Failure handling:
//   - If forwarding fails (e.g. the connection is lost), the script is
//     killed and its stdin is closed, so Execute cannot block forever
//     waiting for it.
//   - If the script fails or ctx is cancelled while forwarding is blocked
//     reading from the connection, the connection's read deadline is
//     expired to unblock that read. It is then reset to zero before
//     returning, which also clears any read deadline previously set by the
//     caller.
func (p *Conn) Execute(ctx context.Context, script Script) error {
	// Release every context derived below when Execute returns, including on
	// the early-return paths before the errgroup is waited on.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// gctx is cancelled as soon as either worker goroutine returns an error,
	// or when the parent ctx is done. Binding the command to it kills the
	// script when forwarding fails, instead of leaving cmd.Wait blocked on a
	// script that is still waiting for input.
	g, gctx := errgroup.WithContext(ctx)

	cmd := exec.CommandContext(gctx, script.File, script.Args...)
	cmd.Env = script.Env
	cmd.Stderr = os.Stderr
	cmd.Stdout = p.Conn
	cmdStdin, err := cmd.StdinPipe()
	if err != nil {
		return err
	}

	debugf("start cmd: %s", cmd)
	if err := cmd.Start(); err != nil {
		return err
	}

	// exited is cancelled when the script terminates (or gctx is done),
	// telling forwardLines to stop at the next line boundary.
	exited, markExited := context.WithCancel(gctx)
	defer markExited()

	fwdDone := make(chan struct{})
	g.Go(func() error {
		defer close(fwdDone)
		// Give the script EOF on stdin once we stop forwarding. The os/exec
		// docs allow closing the StdinPipe early. Wait closes it as well.
		defer cmdStdin.Close()
		return forwardLines(exited, cmdStdin, p.br)
	})
	g.Go(func() error {
		defer markExited()
		return cmd.Wait()
	})

	// forwardLines only notices cancellation once more data arrives on the
	// connection. If the script failed or ctx was cancelled, expire the read
	// deadline so a pending read returns immediately. This goroutine belongs
	// to the group, so on the success path gctx is never cancelled before
	// forwarding is done and the deadline is left untouched.
	var interrupted bool
	g.Go(func() error {
		select {
		case <-fwdDone:
		case <-gctx.Done():
			interrupted = true
			_ = p.Conn.SetReadDeadline(time.Now())
		}
		return nil
	})

	err = g.Wait()
	if interrupted {
		// Leave the connection readable for the caller.
		_ = p.Conn.SetReadDeadline(time.Time{})
	}
	return err
}

// forwardLines forwards data read from r (the connection) to w (the spawned
// process's stdin) line by line. Every forwarded byte is also echoed to
// os.Stdout.
//
// The line delimiter is CR or LF, but to facilitate scripting we forward
// each line with LF ending only. Delimiters found at the start of a line
// (i.e. empty lines, or the LF of a CRLF pair) are discarded.
//
// ctx is expected to be cancelled once the child process has exited. It is
// only checked at line boundaries: right after the next byte arrives on r,
// and during a 100ms wait after each forwarded line. A cancellation is
// therefore noticed only when more data arrives or shortly after a line is
// forwarded.
//
// Return values:
//   - nil once ctx is done at a line boundary. When this happens after a
//     peek, the peeked byte is left unread in r for the caller.
//   - r's error if reading fails (e.g. connection lost).
//   - a wrapped error if flushing a line to w (or os.Stdout) fails.
func forwardLines(ctx context.Context, w io.Writer, r *bufio.Reader) error {
	// Copy the lines to stdout so the user can see what's going on.
	stdinBuffered := bufio.NewWriter(io.MultiWriter(w, os.Stdout))
	defer stdinBuffered.Flush()

	isDelimiter := func(b byte) bool { return b == '\n' || b == '\r' }

	var isPrefix bool // true if we're in the middle of a line
	for {
		if !isPrefix {
			// At a line boundary: block until the next byte arrives,
			// discarding delimiters (empty lines) before checking ctx.
			debugf("wait next line")
			switch peek, err := r.Peek(1); {
			case err != nil:
				// Connection lost.
				debugf("connection lost while waiting for next line")
				return err
			case len(peek) > 0 && isDelimiter(peek[0]):
				debugf("discard %q", peek)
				// Error ignored: Peek just ensured one byte is buffered.
				r.Discard(1)
				continue
			case ctx.Err() != nil:
				// Child process exited before the next line
				// arrived. We're done. The peeked byte is left
				// unread in r.
				debugf("cmd exited while waiting for next line")
				return nil
			default:
				debugf("at next line")
			}
		}

		// Read and forward the byte.
		// Replace CR with LF for convenience.
		b, err := r.ReadByte()
		if err != nil {
			// Connection lost.
			debugf("connection lost while reading next byte")
			return err
		}
		if b == '\r' {
			b = '\n'
		}
		// Error ignored here: bufio.Writer errors are sticky and are
		// reported by the Flush at the end of the line.
		stdinBuffered.WriteByte(b)

		isPrefix = !isDelimiter(b)
		if isPrefix {
			// Keep going. We're in the middle of a line.
			continue
		}

		// A line was just terminated.
		// Flush and wait a bit to check if the process exits.
		if err := stdinBuffered.Flush(); err != nil {
			return fmt.Errorf("child process exited prematurely: %w", err)
		}
		select {
		case <-time.After(100 * time.Millisecond):
			// Child process is still alive. Keep going.
		case <-ctx.Done():
			// Child process exited. We're done.
			return nil
		}
	}
}

// debugf emits a debug message through the internal debug package, tagged
// with this package's "prehook: " prefix.
func debugf(format string, args ...any) {
	const prefix = "prehook: "
	debug.Printf(prefix+format, args...)
}