1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269
|
// Copyright 2009 The Go Authors. All rights reserved.
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rpc
import (
"fmt"
"net/http"
"reflect"
"strings"
)
// ----------------------------------------------------------------------------
// Codec
// ----------------------------------------------------------------------------
// Codec is a factory for per-request codec state: the server asks it for a
// fresh CodecRequest each time it handles an HTTP request.
type Codec interface {
	// NewRequest returns the CodecRequest that will decode req and encode
	// the response for it.
	NewRequest(req *http.Request) CodecRequest
}
// CodecRequest handles a single RPC exchange for one serialization scheme:
// it decodes the incoming call and encodes the outgoing reply.
type CodecRequest interface {
	// Method reads the request and reports the name of the RPC method
	// being invoked.
	Method() (name string, err error)
	// ReadRequest reads the request and fills args with the RPC method
	// arguments.
	ReadRequest(args interface{}) error
	// WriteResponse writes the response to w using the RPC method reply.
	// methodErr is the error returned by the method call, if any.
	WriteResponse(w http.ResponseWriter, reply interface{}, methodErr error) error
}
// ----------------------------------------------------------------------------
// Server
// ----------------------------------------------------------------------------
// NewServer builds an RPC server with an empty codec table and an empty
// service registry. No intercept, before or after function is installed.
func NewServer() *Server {
	srv := new(Server)
	srv.codecs = make(map[string]Codec)
	srv.services = new(serviceMap)
	return srv
}
// RequestInfo describes one RPC call to the functions registered with
// RegisterInterceptFunc, RegisterBeforeFunc and RegisterAfterFunc.
//
// Not every field is populated on every call; see the per-field notes.
type RequestInfo struct {
// Method is the RPC method name reported by the codec request. It is left
// empty when the after function is invoked from the server's error path.
Method string
// Error is only set for the after function: either the error returned by
// the service method (nil on success) or the error that aborted the request.
Error error
// Request is the HTTP request being served. It is left nil when the after
// function is invoked from the server's error path.
Request *http.Request
// StatusCode is only set for the after function: 200 when the response was
// encoded successfully, otherwise the HTTP status of the error response.
StatusCode int
}
// Server serves registered RPC services using registered codecs.
//
// It handles HTTP requests through its ServeHTTP method. Create one with
// NewServer so that the codec map and the service registry are initialized.
type Server struct {
// codecs maps a lower-cased Content-Type media type to the codec that
// handles it.
codecs map[string]Codec
// services is the registry consulted to resolve RPC method names.
services *serviceMap
// interceptFunc, when set, runs before beforeFunc and may return a
// replacement *http.Request (a nil result keeps the original request).
interceptFunc func(i *RequestInfo) *http.Request
// beforeFunc, when set, runs just before the service method is called.
beforeFunc func(i *RequestInfo)
// afterFunc, when set, runs after the response (success or error) has been
// written.
afterFunc func(i *RequestInfo)
}
// RegisterCodec associates codec with the given content type.
//
// A codec implements one serialization scheme (JSON, XML, ...). When a
// request arrives, the server selects the codec from the request's
// "Content-Type" header, ignoring any parameters such as the charset. The
// content type is stored lower-cased, so the match is case-insensitive.
// Registering the same content type again replaces the previous codec.
func (s *Server) RegisterCodec(codec Codec, contentType string) {
	key := strings.ToLower(contentType)
	s.codecs[key] = codec
}
// RegisterService adds a new HTTP service to the server.
//
// name is optional: when empty, the service name is inferred from the
// receiver's type name.
//
// A method of the receiver is exposed only if all of these hold:
//
//   - The receiver is exported (begins with an upper case letter) or local
//     (defined in the package registering the service).
//   - The method name is exported.
//   - The method has three arguments: *http.Request, *args, *reply.
//   - All three arguments are pointers.
//   - The second and third arguments are exported or local.
//   - The method has return type error.
//
// Every other method is ignored. The returned error is whatever the
// service registry reports for this registration.
func (s *Server) RegisterService(receiver interface{}, name string) error {
	// HTTP services take the *http.Request as their first argument.
	const passRequest = true
	return s.services.register(receiver, name, passRequest)
}
// RegisterTCPService adds a new TCP service to the server. Unlike services
// added with RegisterService, its methods are not handed the HTTP request.
//
// name is optional: when empty, the service name is inferred from the
// receiver's type name.
//
// A method of the receiver is exposed only if all of these hold:
//
//   - The receiver is exported (begins with an upper case letter) or local
//     (defined in the package registering the service).
//   - The method name is exported.
//   - The method has two arguments: *args, *reply.
//   - Both arguments are pointers.
//   - Both arguments are exported or local.
//   - The method has return type error.
//
// Every other method is ignored. The returned error is whatever the
// service registry reports for this registration.
func (s *Server) RegisterTCPService(receiver interface{}, name string) error {
	// TCP services only take (*args, *reply); no request is passed.
	const passRequest = false
	return s.services.register(receiver, name, passRequest)
}
// HasMethod reports whether method is registered on this server.
//
// method uses dotted notation, as in "Service.Method". The answer is true
// exactly when the service registry resolves the name without error.
func (s *Server) HasMethod(method string) bool {
	_, _, err := s.services.get(method)
	return err == nil
}
// RegisterInterceptFunc installs f as the server's intercept function.
//
// The intercept function runs once the RPC method has been resolved and its
// arguments decoded, ahead of the before function. It receives the request
// and the method name, and may return a replacement *http.Request (for
// example one carrying extra context values) that is used for the rest of
// the call; returning nil keeps the original request.
//
// Only one intercept function is kept: each call replaces the previously
// registered one.
func (s *Server) RegisterInterceptFunc(f func(i *RequestInfo) *http.Request) {
	s.interceptFunc = f
}
// RegisterBeforeFunc installs f as the server's before function.
//
// The before function runs immediately before the service method is called,
// after the intercept function (if any). It receives the request and the
// RPC method name. It is not run for requests rejected earlier, e.g. for an
// unknown method or undecodable arguments.
//
// Only one before function is kept: each call replaces the previously
// registered one.
func (s *Server) RegisterBeforeFunc(f func(i *RequestInfo)) {
	s.beforeFunc = f
}
// RegisterAfterFunc installs f as the server's after function.
//
// The after function runs once the response has been written. On success it
// receives the request, the RPC method name, the error returned by the
// service method (if any) and status code 200. When the server answers with
// an error response instead, it receives only that error and its HTTP
// status code.
//
// Only one after function is kept: each call replaces the previously
// registered one.
func (s *Server) RegisterAfterFunc(f func(i *RequestInfo)) {
	s.afterFunc = f
}
// ServeHTTP decodes one RPC call from r, invokes the matching registered
// service method and writes the encoded reply to w.
//
// Processing steps:
//
//  1. Anything other than POST is rejected with 405; the response carries
//     an "Allow: POST" header.
//  2. The codec is selected from the Content-Type media type. Parameters
//     such as charset and surrounding whitespace are ignored and the match
//     is case-insensitive. With no Content-Type and exactly one registered
//     codec, that codec is used. An unknown type is rejected with 415.
//  3. The method name is resolved and its arguments decoded; any failure
//     is rejected with 400.
//  4. The intercept function (which may replace the request) and then the
//     before function run, if registered.
//  5. The service method is called and the codec writes the response. A
//     failure to write the response is reported with 400.
//  6. The after function runs, if registered.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		// A 405 response is expected to advertise the supported methods.
		w.Header().Set("Allow", http.MethodPost)
		s.writeError(w, http.StatusMethodNotAllowed, "rpc: POST method required, received "+r.Method)
		return
	}

	// Reduce the Content-Type to its bare media type. Whitespace is legal
	// around the ";" separator ("application/json ; charset=utf-8"), so it
	// must be trimmed or such requests would never match a codec.
	contentType := r.Header.Get("Content-Type")
	if idx := strings.Index(contentType, ";"); idx != -1 {
		contentType = contentType[:idx]
	}
	contentType = strings.TrimSpace(contentType)

	var codec Codec
	if contentType == "" && len(s.codecs) == 1 {
		// If Content-Type is not set and only one codec has been registered,
		// then default to that codec.
		for _, c := range s.codecs {
			codec = c
		}
	} else if codec = s.codecs[strings.ToLower(contentType)]; codec == nil {
		s.writeError(w, http.StatusUnsupportedMediaType, "rpc: unrecognized Content-Type: "+contentType)
		return
	}

	// Create a new codec request.
	codecReq := codec.NewRequest(r)

	// Get service method to be called.
	method, errMethod := codecReq.Method()
	if errMethod != nil {
		s.writeError(w, http.StatusBadRequest, errMethod.Error())
		return
	}
	serviceSpec, methodSpec, errGet := s.services.get(method)
	if errGet != nil {
		s.writeError(w, http.StatusBadRequest, errGet.Error())
		return
	}

	// Decode the args into a freshly allocated value of the method's
	// argument type.
	args := reflect.New(methodSpec.argsType)
	if errRead := codecReq.ReadRequest(args.Interface()); errRead != nil {
		s.writeError(w, http.StatusBadRequest, errRead.Error())
		return
	}

	// Call the registered intercept function; a non-nil result replaces
	// the request for the remainder of the call.
	if s.interceptFunc != nil {
		if req := s.interceptFunc(&RequestInfo{Request: r, Method: method}); req != nil {
			r = req
		}
	}

	// Call the registered before function.
	if s.beforeFunc != nil {
		s.beforeFunc(&RequestInfo{Request: r, Method: method})
	}

	// Call the service method. The HTTP request is omitted from the
	// argument list if the service method doesn't accept it.
	reply := reflect.New(methodSpec.replyType)
	callArgs := make([]reflect.Value, 0, 4)
	callArgs = append(callArgs, serviceSpec.rcvr)
	if serviceSpec.passReq {
		callArgs = append(callArgs, reflect.ValueOf(r))
	}
	callArgs = append(callArgs, args, reply)
	errValue := methodSpec.method.Func.Call(callArgs)

	// Cast the result to error if needed.
	var errResult error
	if errInter := errValue[0].Interface(); errInter != nil {
		errResult = errInter.(error)
	}

	// Prevents Internet Explorer from MIME-sniffing a response away
	// from the declared content-type.
	w.Header().Set("x-content-type-options", "nosniff")

	// Encode the response.
	if errWrite := codecReq.WriteResponse(w, reply.Interface(), errResult); errWrite != nil {
		s.writeError(w, http.StatusBadRequest, errWrite.Error())
		return
	}

	// Call the registered after function.
	if s.afterFunc != nil {
		s.afterFunc(&RequestInfo{
			Request:    r,
			Method:     method,
			Error:      errResult,
			StatusCode: http.StatusOK,
		})
	}
}
// writeError sends a plain-text error response with the given HTTP status
// and message, then reports the failure to the registered after function,
// if any.
//
// The RequestInfo handed to the after function carries only the error
// (whose text is msg) and the status code; the method name and the request
// are not available here.
func (s *Server) writeError(w http.ResponseWriter, status int, msg string) {
	// Headers have to be set before WriteHeader: changes made to the header
	// map afterwards are not sent to the client.
	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
	w.WriteHeader(status)
	// Best effort: if the client is gone there is nobody left to tell.
	fmt.Fprint(w, msg)
	if s.afterFunc != nil {
		s.afterFunc(&RequestInfo{
			// msg is data, not a format string: pass it through "%s" so a
			// literal '%' in it is preserved.
			Error:      fmt.Errorf("%s", msg),
			StatusCode: status,
		})
	}
}
|