File: middleware.go

package info (click to toggle)
relic 7.6.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,108 kB
  • sloc: sh: 230; makefile: 10
file content (107 lines) | stat: -rw-r--r-- 2,548 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package compresshttp

import (
	"errors"
	"io"
	"log"
	"net/http"
)

// Middleware wraps next with transparent request decompression and response
// compression.
//
// Before next runs, the request is passed through DecompressRequest. An error
// matching ErrUnacceptableEncoding is answered with 415 Unsupported Media Type;
// any other error is logged and answered with 400 Bad Request. The response
// encoding is chosen with selectEncoding from the request's Accept-Encoding
// header. When no encoding (or only identity) is selected, next is called
// directly. Otherwise its output is compressed through a responseCompressor
// that is closed once next returns.
func Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		// Advertise AcceptedEncodings in the Accept-Encoding response header
		// (presumably the request encodings DecompressRequest understands).
		hdr.Set(acceptEncoding, AcceptedEncodings)
		// The response body depends on the request's Accept-Encoding, so caches
		// must not serve one client's variant to a client that sent a
		// different Accept-Encoding.
		hdr.Add("Vary", acceptEncoding)

		// decompress request
		if err := DecompressRequest(r); err != nil {
			// errors.Is also recognizes the sentinel when it has been wrapped.
			if errors.Is(err, ErrUnacceptableEncoding) {
				http.Error(w, "invalid content-encoding", http.StatusUnsupportedMediaType)
				return
			}
			log.Printf("error: decoding request from %s: %+v", r.RemoteAddr, err)
			http.Error(w, "failed to decompress request", http.StatusBadRequest)
			return
		}

		// choose response encoding
		encoding := selectEncoding(r.Header.Get(acceptEncoding))
		if encoding == "" || encoding == EncodingIdentity {
			// shortcut if no encoding is possible
			next.ServeHTTP(w, r)
			return
		}
		// Any length known at this point describes the uncompressed body.
		hdr.Del("Content-Length")
		// wrap writer in compression and call handler
		wrapped := &responseCompressor{rw: w, encoding: encoding}
		next.ServeHTTP(wrapped, r)
		// Finish the compressed stream. The handler has already returned, so an
		// error here can only be logged.
		if err := wrapped.Close(); err != nil {
			log.Printf("error: flushing response to %s: %+v", r.RemoteAddr, err)
		}
	})
}

// responseCompressor is an http.ResponseWriter that passes headers and status
// through to a wrapped writer and sends the body through a compressor.
type responseCompressor struct {
	// rw is the wrapped ResponseWriter that ultimately receives everything.
	rw http.ResponseWriter
	// wc compresses into rw; it stays nil until the first body write.
	wc io.WriteCloser
	// encoding is the negotiated content-coding. WriteHeader clears it for
	// statuses >= 300 so that error bodies are not compressed.
	encoding string
	// wroteHeader records whether a status has been sent to rw.
	wroteHeader bool
}

// WriteHeader forwards status to the wrapped writer. On the first final
// (non-1xx) status it also decides whether the body will be compressed:
//   - A status >= 300 disables compression.
//   - Otherwise the negotiated Content-Encoding is advertised, and any
//     handler-supplied Content-Length is dropped.
func (w *responseCompressor) WriteHeader(status int) {
	if status >= 100 && status < 200 {
		// Informational responses (e.g. 103 Early Hints) may precede the final
		// status. Pass them through without latching the compression decision
		// or attaching Content-Encoding to them.
		w.rw.WriteHeader(status)
		return
	}
	if !w.wroteHeader {
		switch {
		case status >= 300:
			// don't compress errors
			w.encoding = ""
		case w.encoding != "" && w.encoding != EncodingIdentity:
			w.Header().Set(contentEncoding, w.encoding)
			// A Content-Length set by the handler (e.g. by http.ServeContent)
			// describes the uncompressed body and would be wrong on the wire.
			w.Header().Del("Content-Length")
		}
		w.wroteHeader = true
	}
	w.rw.WriteHeader(status)
}

// Write compresses d into the response body.
//
// The compressor is created lazily on the first call, so that WriteHeader can
// still disable compression (for statuses >= 300) before any body bytes exist.
// If the handler never called WriteHeader, an implicit 200 OK is sent first.
func (w *responseCompressor) Write(d []byte) (int, error) {
	if w.wc == nil {
		if !w.wroteHeader {
			// The handler may have set Content-Length for the uncompressed
			// body; it would be wrong once the body is compressed.
			w.Header().Del("Content-Length")
			w.WriteHeader(http.StatusOK)
		}
		wc, err := setupCompression(w.encoding, w.rw)
		if err != nil {
			// Leave w.wc nil so that a failed (possibly typed-nil) writer is
			// never stored and used by a later Write or Close.
			return 0, err
		}
		w.wc = wc
	}
	return w.wc.Write(d)
}

// Header returns the header map of the wrapped ResponseWriter, so handler
// changes land directly on the real response.
func (w *responseCompressor) Header() http.Header {
	underlying := w.rw
	return underlying.Header()
}

// Flush implements http.Flusher. It first flushes the compressor (when it
// supports an error-returning Flush) and then flushes the wrapped
// ResponseWriter (when that implements http.Flusher). Compressor flush errors
// are logged rather than returned, because http.Flusher has no error result.
func (w *responseCompressor) Flush() {
	rf, canFlush := w.rw.(http.Flusher)
	if canFlush && !w.wroteHeader {
		// Flushing the wrapped writer commits the response headers. Start the
		// body through our own Write first so that the status and
		// Content-Encoding are set now. Otherwise the headers would go out
		// without Content-Encoding, and the compressed bytes written later
		// would be misinterpreted by the client.
		if _, err := w.Write(nil); err != nil {
			log.Println("warning: starting compressor:", err)
		}
	}
	// flush compressor
	if cf, ok := w.wc.(flusher); ok {
		if err := cf.Flush(); err != nil {
			log.Println("warning: flushing compressor:", err)
		}
	}
	// flush response
	if canFlush {
		rf.Flush()
	}
}

// Close closes the compressor started by Write and returns its error. When
// nothing was ever written, no compressor exists and Close returns nil.
// Middleware calls it after the wrapped handler returns.
func (w *responseCompressor) Close() error {
	if wc := w.wc; wc != nil {
		return wc.Close()
	}
	return nil
}

// flusher matches writers that expose an error-returning Flush method.
// responseCompressor.Flush checks its compressor against this interface to
// push buffered compressed output to the underlying response.
type flusher interface {
	Flush() (err error)
}