1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274
|
package prefork
import (
"errors"
"log"
"net"
"os"
"os/exec"
"runtime"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/reuseport"
)
const (
	// preforkChildEnvVariable is set to "1" in the environment of every child
	// process the master starts; IsChild looks for it.
	preforkChildEnvVariable = "FASTHTTP_PREFORK_CHILD"

	// defaultNetwork is used whenever Prefork.Network is left empty.
	defaultNetwork = "tcp4"
)

// defaultLogger writes to standard error with the standard log flags. It is
// used whenever Prefork.Logger is nil.
var defaultLogger Logger = log.New(os.Stderr, "", log.LstdFlags)

var (
	// ErrOverRecovery is returned when child prefork processes have exited
	// (and been started over) more times than RecoverThreshold allows.
	ErrOverRecovery = errors.New("exceeding the value of RecoverThreshold")

	// ErrOnlyReuseportOnWindows is returned on Windows when Reuseport is
	// false. Without Reuseport, the listener is handed to children through
	// inherited file descriptors, and that mode is rejected on Windows.
	ErrOnlyReuseportOnWindows = errors.New("windows only supports Reuseport = true")
)

// Logger is the minimal logging interface used by Prefork; *log.Logger
// satisfies it.
type Logger interface {
	// Printf must have the same semantics as log.Printf.
	Printf(format string, args ...any)
}
// Prefork implements fasthttp server prefork.
//
// A master process starts one child process per GOMAXPROCS slot. Each child
// serves connections on the same address while running with GOMAXPROCS(1).
// Running several single-threaded processes instead of one multi-threaded
// process can increase performance significantly, because Go doesn't have to
// share and manage memory between cores.
//
// WARNING: using prefork prevents the use of any global state!
// Things like in-memory caches won't work, since every child process has its
// own memory.
type Prefork struct {
	// Logger receives the master's diagnostic messages.
	//
	// When nil, a standard library logger writing to os.Stderr is used.
	Logger Logger

	// ln is the listener opened by this process: the shared TCP listener in
	// the master (when Reuseport is false), or the serving listener in a child.
	ln net.Listener

	// ServeFunc serves plain HTTP on a child's listener.
	ServeFunc func(ln net.Listener) error

	// ServeTLSFunc serves HTTPS on a child's listener using the certificate
	// and key files at the given paths.
	ServeTLSFunc func(ln net.Listener, certFile, keyFile string) error

	// ServeTLSEmbedFunc serves HTTPS on a child's listener using in-memory
	// certificate and key data.
	ServeTLSEmbedFunc func(ln net.Listener, certData, keyData []byte) error

	// Network is the network to listen on; it must be "tcp", "tcp4" or "tcp6".
	//
	// New sets it to "tcp4", and an empty value is treated as "tcp4" too.
	Network string

	// files holds the master's copy of the shared listener descriptor. It is
	// passed to every child as its first extra file (fd 3 in the child).
	files []*os.File

	// RecoverThreshold bounds how many child exits the master tolerates.
	// Every exited child is started over until the number of exits exceeds
	// this value. The master then terminates the remaining children and
	// returns ErrOverRecovery.
	RecoverThreshold int

	// Reuseport selects how children obtain their listener.
	//
	// When true, every child binds the address itself with a reuseport
	// listener (see https://www.nginx.com/blog/socket-sharding-nginx-release-1-9-1/).
	// When false, the master opens one TCP listener and the children inherit
	// its file descriptor; this mode is rejected on Windows.
	//
	// It's disabled by default.
	Reuseport bool
}
// IsChild reports whether the current process is a prefork child, i.e.
// whether FASTHTTP_PREFORK_CHILD is set to "1" in its environment.
func IsChild() bool {
	value, _ := os.LookupEnv(preforkChildEnvVariable)
	return value == "1"
}
// New wraps the fasthttp server to run with preforked processes.
//
// The returned Prefork listens on "tcp4", reuses s.Logger, and serves
// through s.Serve, s.ServeTLS and s.ServeTLSEmbed. It tolerates
// GOMAXPROCS/2 child exits before giving up. s must not be nil.
func New(s *fasthttp.Server) *Prefork {
	p := &Prefork{
		Network: defaultNetwork,
		Logger:  s.Logger,
	}
	p.RecoverThreshold = runtime.GOMAXPROCS(0) / 2

	p.ServeFunc = s.Serve
	p.ServeTLSFunc = s.ServeTLS
	p.ServeTLSEmbedFunc = s.ServeTLSEmbed

	return p
}
// logger returns the configured Logger, or defaultLogger when none is set.
func (p *Prefork) logger() Logger {
	l := p.Logger
	if l == nil {
		l = defaultLogger
	}
	return l
}
// listen opens the listener a child process serves on. It also limits the
// child to GOMAXPROCS(1).
//
// With Reuseport, the child binds addr itself through reuseport.Listen.
// Otherwise addr is not used: the listener is rebuilt from file descriptor 3,
// which the master passes as the first entry of exec.Cmd.ExtraFiles.
// An empty p.Network is replaced with defaultNetwork.
func (p *Prefork) listen(addr string) (net.Listener, error) {
	// Each child serves its share of the traffic with a single P; parallelism
	// comes from running several child processes.
	runtime.GOMAXPROCS(1)

	if p.Network == "" {
		p.Network = defaultNetwork
	}

	if p.Reuseport {
		return reuseport.Listen(p.Network, addr)
	}

	// net.FileListener works on its own duplicate of the descriptor. The
	// inherited *os.File must therefore be closed separately; otherwise fd 3
	// stays open for the child's lifetime.
	inherited := os.NewFile(3, "")
	ln, err := net.FileListener(inherited)
	// Best effort: the listener owns an independent duplicate, so failing to
	// close the inherited copy does not affect serving.
	_ = inherited.Close()

	return ln, err
}
// setTCPListenerFiles opens the master's shared TCP listener on addr.
//
// On success it stores the listener in p.ln and a duplicate of its descriptor
// in p.files, so child processes can inherit it through exec.Cmd.ExtraFiles.
// On failure p.ln and p.files are left untouched and no socket stays open.
// An empty p.Network is replaced with defaultNetwork.
func (p *Prefork) setTCPListenerFiles(addr string) error {
	if p.Network == "" {
		p.Network = defaultNetwork
	}

	tcpAddr, err := net.ResolveTCPAddr(p.Network, addr)
	if err != nil {
		return err
	}

	tcpListener, err := net.ListenTCP(p.Network, tcpAddr)
	if err != nil {
		return err
	}

	// File returns a duplicate of the socket descriptor; the listener itself
	// stays open and usable.
	fl, err := tcpListener.File()
	if err != nil {
		// The caller only installs its cleanup after a successful return.
		// Close the listener here so the bound port is not leaked. The close
		// error is secondary to err.
		_ = tcpListener.Close()
		return err
	}

	p.ln = tcpListener
	p.files = []*os.File{fl}
	return nil
}
// doCommand re-executes the current binary as a prefork child and starts it
// without waiting for it. The child gets os.Args[0] with the same arguments.
//
// The child shares the parent's stdout and stderr. It inherits the parent's
// environment plus FASTHTTP_PREFORK_CHILD=1, so IsChild reports true. It
// receives p.files as extra file descriptors starting at fd 3.
//
// The returned *exec.Cmd is never nil. Callers should only rely on
// cmd.Process when the returned error is nil.
func (p *Prefork) doCommand() (*exec.Cmd, error) {
	env := os.Environ()
	env = append(env, preforkChildEnvVariable+"=1")

	// Re-executing our own binary with our own arguments is intentional.
	// #nosec G204
	cmd := exec.Command(os.Args[0], os.Args[1:]...)
	cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
	cmd.Env = env
	cmd.ExtraFiles = p.files

	return cmd, cmd.Start()
}
// prefork runs the master side of the prefork model.
//
// It starts one child per GOMAXPROCS slot and starts a replacement whenever a
// child exits. It returns ErrOverRecovery once the number of exits exceeds
// p.RecoverThreshold, or the error from a child that could not be started.
//
// Unless p.Reuseport is set, the master first opens the shared TCP listener on
// addr and hands its descriptor to every child. That mode is rejected on
// Windows with ErrOnlyReuseportOnWindows.
//
// On return, children that are still running are killed. The master's
// listener and its descriptor copies are then closed, releasing the port.
func (p *Prefork) prefork(addr string) (err error) {
	if !p.Reuseport {
		if runtime.GOOS == "windows" {
			return ErrOnlyReuseportOnWindows
		}

		if err = p.setTCPListenerFiles(addr); err != nil {
			return err
		}

		// Release the master's copies of the shared socket once supervision
		// ends. Children hold their own descriptors and are unaffected. A
		// close error is reported only if nothing else failed.
		defer func() {
			for _, f := range p.files {
				if e := f.Close(); err == nil {
					err = e
				}
			}
			p.files = nil

			if e := p.ln.Close(); err == nil {
				err = e
			}
		}()
	}

	type procSig struct {
		err error
		pid int
	}

	goMaxProcs := runtime.GOMAXPROCS(0)
	// At most goMaxProcs children are alive at once. This buffer therefore
	// lets every waiter goroutine deliver its exit notification even after the
	// loop below has stopped receiving.
	sigCh := make(chan procSig, goMaxProcs)
	childProcs := make(map[int]*exec.Cmd, goMaxProcs)

	// Kill the children that are still running when the master gives up.
	defer func() {
		for _, proc := range childProcs {
			// Best effort: the child may already have exited on its own.
			_ = proc.Process.Kill()
		}
	}()

	// startChild launches one child, records it, and reports its exit on sigCh.
	startChild := func() error {
		cmd, startErr := p.doCommand()
		if startErr != nil {
			return startErr
		}
		childProcs[cmd.Process.Pid] = cmd
		go func() {
			sigCh <- procSig{pid: cmd.Process.Pid, err: cmd.Wait()}
		}()
		return nil
	}

	for i := 0; i < goMaxProcs; i++ {
		if err = startChild(); err != nil {
			p.logger().Printf("failed to start a child prefork process, error: %v\n", err)
			return err
		}
	}

	var exitedProcs int
	for sig := range sigCh {
		delete(childProcs, sig.pid)

		p.logger().Printf("one of the child prefork processes exited with "+
			"error: %v", sig.err)

		exitedProcs++
		if exitedProcs > p.RecoverThreshold {
			p.logger().Printf("child prefork processes exit too many times, "+
				"which exceeds the value of RecoverThreshold(%d), "+
				"exiting the master process.\n", p.RecoverThreshold)
			err = ErrOverRecovery
			break
		}

		if err = startChild(); err != nil {
			p.logger().Printf("failed to start a child prefork process, error: %v\n", err)
			break
		}
	}

	return err
}
// ListenAndServe serves HTTP requests from the given TCP addr.
//
// In the master process it starts and supervises the child processes and
// returns only when supervision stops. In a child process it opens the
// listener, records it, and returns whatever ServeFunc returns.
func (p *Prefork) ListenAndServe(addr string) error {
	if !IsChild() {
		return p.prefork(addr)
	}

	ln, err := p.listen(addr)
	if err != nil {
		return err
	}
	p.ln = ln

	return p.ServeFunc(ln)
}
// ListenAndServeTLS serves HTTPS requests from the given TCP addr.
//
// certKey is the path to the TLS private key file and certFile is the path to
// the TLS certificate file. Note the argument order: the key path comes first
// here, but it is forwarded to ServeTLSFunc as (certFile, keyFile), with the
// certificate first.
//
// In the master process it starts and supervises the child processes. In a
// child process it opens the listener and returns whatever ServeTLSFunc returns.
func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error {
	if !IsChild() {
		return p.prefork(addr)
	}

	ln, err := p.listen(addr)
	if err != nil {
		return err
	}
	p.ln = ln

	return p.ServeTLSFunc(ln, certFile, certKey)
}
// ListenAndServeTLSEmbed serves HTTPS requests from the given TCP addr.
//
// certData and keyData must contain valid TLS certificate and key data.
//
// In the master process it starts and supervises the child processes. In a
// child process it opens the listener and returns whatever ServeTLSEmbedFunc
// returns.
func (p *Prefork) ListenAndServeTLSEmbed(addr string, certData, keyData []byte) error {
	if !IsChild() {
		return p.prefork(addr)
	}

	ln, err := p.listen(addr)
	if err != nil {
		return err
	}
	p.ln = ln

	return p.ServeTLSEmbedFunc(ln, certData, keyData)
}
|