1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376
|
package socket
import (
"bufio"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"syscall"
"time"
"unsafe"
"github.com/digitalocean/go-libvirt/internal/constants"
)
// disconnectTimeout is the longest Disconnect will wait, after closing the
// connection, for the listen goroutine to signal that it has exited.
const disconnectTimeout = 5 * time.Second
// Status values carried in the Status field of a packet Header. The numeric
// values are written to the wire as-is, so they are spelled out explicitly.
const (
	// StatusOK is always set for method calls and events. On a reply it
	// means the method completed successfully; on a stream it confirms that
	// the end of file was reached on the stream.
	StatusOK = 0

	// StatusError on a reply means the method call failed and error
	// information is being returned. On a stream it means the stream was
	// aborted before all data was sent.
	StatusError = 1

	// StatusContinue is only used for streams and indicates that further
	// data packets will follow.
	StatusContinue = 2
)
// Packet type values carried in the Type field of a packet Header. The
// numeric values are written to the wire as-is, so they are spelled out
// explicitly.
const (
	// Call marks a request made to the remote server.
	Call = 0

	// Reply marks a server reply.
	Reply = 1

	// Message marks an asynchronous notification.
	Message = 2

	// Stream marks a stream data packet.
	Stream = 3

	// CallWithFDs is used by a client to indicate the request has
	// arguments with file descriptors.
	CallWithFDs = 4

	// ReplyWithFDs is the server-side counterpart of CallWithFDs: a reply
	// that is accompanied by file descriptors.
	ReplyWithFDs = 5
)
// Dialer is an interface for connecting to libvirt's underlying socket.
// Socket.Connect calls Dial once per connection attempt and uses the returned
// net.Conn for all subsequent reads and writes.
type Dialer interface {
// Dial establishes the connection and returns it, or returns a non-nil
// error if the connection could not be made.
Dial() (net.Conn, error)
}
// Router is an interface used to route packets to the appropriate clients.
// The listen loop calls Route once for every packet it has read in full.
type Router interface {
// Route receives the decoded packet header and the raw payload bytes that
// followed it on the wire (the payload may be empty).
Route(*Header, []byte)
}
// Socket represents a libvirt Socket and its connection state
type Socket struct {
// dialer creates the underlying connection each time Connect is called.
dialer Dialer
// router receives every packet read from the connection.
router Router
// conn is the current connection; reader and writer are buffered wrappers
// around it. All three are assigned together in Connect.
conn net.Conn
reader *bufio.Reader
writer *bufio.Writer
// used to serialize any Socket writes and any updates to conn, reader, or
// writer
mu *sync.Mutex
// disconnected is closed when the listen goroutine associated with a
// Socket connection has returned. Connect replaces it with a fresh channel,
// so a closed channel means there is currently no live connection.
disconnected chan struct{}
}
// packet represents a RPC request or response.
// It holds only the fixed-size leading part of a packet (length prefix and
// header); SendPacket writes the variable-length payload separately, right
// after this struct.
type packet struct {
// Size of packet, in bytes, including length.
// Len + Header + Payload
Len uint32
// Header identifies the call this packet belongs to.
Header Header
}
// Global packet instance, for use with unsafe.Sizeof(). In this file it is
// only ever passed to unsafe.Sizeof (as a whole, or its Len and Header
// fields) to obtain the on-wire sizes; its contents are never read.
var _p packet
// Header is a libvirt rpc packet header.
// Each field is transmitted as a big-endian 32-bit value in the order declared
// here (extractHeader decodes by fixed offset and SendPacket encodes the
// struct with binary.Write), so the field order and sizes must not change.
type Header struct {
// Program identifier
Program uint32
// Program version
Version uint32
// Remote procedure identifier
Procedure uint32
// Call type, e.g., Reply
Type uint32
// Call serial number
Serial int32
// Request status, e.g., StatusOK
Status uint32
}
// New returns a Socket that will use dialer to establish connections and
// router to deliver incoming packets. The returned Socket is not connected;
// call Connect to open the connection.
func New(dialer Dialer, router Router) *Socket {
	// A closed disconnected channel is how the Socket represents "no
	// connection", so a new Socket starts out with one that is already closed.
	done := make(chan struct{})
	close(done)

	return &Socket{
		dialer:       dialer,
		router:       router,
		mu:           new(sync.Mutex),
		disconnected: done,
	}
}
// Connect dials libvirt with the Dialer given to New and, on success, starts
// the goroutine that reads and routes incoming packets.
//
// It returns an error if the Socket is already connected, or the Dialer's
// error if dialing fails. The Socket's mutex is held for the whole call.
func (s *Socket) Connect() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if connected := !s.isDisconnected(); connected {
		return errors.New("already connected to socket")
	}

	c, err := s.dialer.Dial()
	if err != nil {
		return err
	}

	s.conn, s.reader, s.writer = c, bufio.NewReader(c), bufio.NewWriter(c)

	// a fresh, open channel marks the Socket as connected; the listen
	// goroutine closes it when it exits
	s.disconnected = make(chan struct{})
	go s.listenAndRoute()

	return nil
}
// Disconnect closes the Socket connection to libvirt and waits for the reader
// goroutine to shut down.
//
// It is a no-op returning nil when the Socket is already disconnected. An
// error from closing the connection is returned as-is. If the reader
// goroutine has not signalled its exit within disconnectTimeout, a timeout
// error is returned.
func (s *Socket) Disconnect() error {
	if s.isDisconnected() {
		return nil
	}

	if err := s.conn.Close(); err != nil {
		return err
	}

	// Closing the connection makes the reader goroutine fail its next read
	// and exit; wait for it so nobody touches the Socket while it is still
	// running, but never wait forever.
	deadline := time.NewTimer(disconnectTimeout)
	defer deadline.Stop()

	select {
	case <-s.disconnected:
		return nil
	case <-deadline.C:
		return errors.New("timed out waiting for Disconnect cleanup")
	}
}
// Disconnected returns a channel that will be closed once the current
// connection is closed. This can happen due to an explicit call to Disconnect
// from the client, or because the listen goroutine stopped reading from the
// connection (for example after a non-temporary read error).
//
// Connect installs a new channel for every connection, so the returned channel
// only tracks the connection that was current at the time of the call. On a
// Socket that has never been connected the channel is already closed.
func (s *Socket) Disconnected() <-chan struct{} {
return s.disconnected
}
// isDisconnected reports, without blocking, whether the Socket currently has
// no live connection, i.e. whether its disconnected channel has been closed.
func (s *Socket) isDisconnected() bool {
	closed := false
	select {
	case <-s.disconnected:
		// a receive only succeeds here once the channel has been closed
		closed = true
	default:
	}
	return closed
}
// listenAndRoute reads packets from the Socket and calls the provided
// Router function to route them. Connect starts it in its own goroutine;
// when reading stops it closes s.disconnected, which marks the Socket as
// disconnected.
func (s *Socket) listenAndRoute() {
// blocks until listen decides the underlying connection is no longer
// usable
listen(s.reader, s.router)
// signal any clients listening that the connection has been disconnected
close(s.disconnected)
}
// listen processes incoming data and routes
// responses to their respective callback handler.
//
// It reads one packet at a time from s: the 4-byte length prefix, the
// fixed-size header, then the payload. The header and payload are handed to
// router.Route. It returns when reading the length prefix fails with a
// non-temporary error, or when a packet declares a length too small to hold
// its own length prefix and header; in both cases the stream can no longer
// be trusted.
func listen(s io.Reader, router Router) {
	// bytes of every packet that are consumed by the length prefix and header
	overhead := int(unsafe.Sizeof(_p))

	for {
		// response packet length
		length, err := pktlen(s)
		if err != nil {
			if isTemporary(err) {
				continue
			}
			// connection is no longer valid, so shutdown
			return
		}

		// payload: packet length minus the length prefix and header.
		// A negative result means the peer declared a packet shorter than
		// its own prefix and header (or, where int is 32 bits, a length that
		// does not fit in an int). Allocating a slice of that size would
		// panic and the packet framing is lost, so give up on the stream.
		size := int(length) - overhead
		if size < 0 {
			return
		}

		// response header
		h, err := extractHeader(s)
		if err != nil {
			// invalid packet
			continue
		}

		buf := make([]byte, size)
		if _, err = io.ReadFull(s, buf); err != nil {
			// invalid packet
			continue
		}

		// route response to caller
		router.Route(h, buf)
	}
}
// isTemporary returns true if the error returned from a read is transient.
// If the error is, or wraps, a *net.OpError, the result is whether that net
// connection error condition is temporary (which means we can keep using the
// connection).
// Errors that carry no net.OpError tend to be things like io.EOF,
// syscall.EINVAL, or io.ErrClosedPipe (i.e. all things that
// indicate the connection in use is no longer valid), so they, and a nil
// error, yield false.
func isTemporary(err error) bool {
	// errors.As also finds an OpError that an intermediate layer has wrapped,
	// which a plain type assertion would miss.
	var opErr *net.OpError
	if errors.As(err, &opErr) {
		return opErr.Temporary()
	}
	return false
}
// pktlen reads the length prefix of the next incoming RPC packet from r and
// returns it decoded from big-endian. If the prefix cannot be read in full,
// it returns 0 together with the read error.
func pktlen(r io.Reader) (uint32, error) {
	// the prefix is exactly as wide as packet.Len
	prefix := make([]byte, unsafe.Sizeof(_p.Len))
	if _, err := io.ReadFull(r, prefix); err != nil {
		return 0, err
	}
	return binary.BigEndian.Uint32(prefix), nil
}
// extractHeader reads one packet header from r and returns it decoded. The
// header is six consecutive big-endian 32-bit fields in the order Program,
// Version, Procedure, Type, Serial, Status. If the header cannot be read in
// full, it returns nil together with the read error.
func extractHeader(r io.Reader) (*Header, error) {
	raw := make([]byte, unsafe.Sizeof(_p.Header))
	if _, err := io.ReadFull(r, raw); err != nil {
		return nil, err
	}

	// word returns the i-th 32-bit field of the raw header.
	word := func(i int) uint32 {
		return binary.BigEndian.Uint32(raw[4*i : 4*i+4])
	}

	hdr := new(Header)
	hdr.Program = word(0)
	hdr.Version = word(1)
	hdr.Procedure = word(2)
	hdr.Type = word(3)
	hdr.Serial = int32(word(4))
	hdr.Status = word(5)
	return hdr, nil
}
// SendPacket sends a single packet to libvirt on the socket connection.
//
// The packet consists of a length prefix, a header built from program, proc,
// typ, serial and status (with the protocol version taken from the constants
// package), and then payload, which may be nil or empty. Everything is
// written through the buffered writer and flushed before returning, and the
// Socket's mutex is held while writing so packets are not interleaved.
//
// It returns syscall.EINVAL without writing anything if the Socket is
// disconnected, and otherwise the first write or flush error, if any.
func (s *Socket) SendPacket(
	serial int32,
	proc uint32,
	program uint32,
	payload []byte,
	typ uint32,
	status uint32,
) error {
	// this mirrors what a lot of net code return on use of a no
	// longer valid connection
	if s.isDisconnected() {
		return syscall.EINVAL
	}

	var p packet
	p.Header = Header{
		Program:   program,
		Version:   constants.ProtocolVersion,
		Procedure: proc,
		Type:      typ,
		Serial:    serial,
		Status:    status,
	}
	// total size on the wire: length prefix + header + payload (len of a nil
	// payload is 0)
	p.Len = uint32(int(unsafe.Sizeof(p.Len)) + int(unsafe.Sizeof(p.Header)) + len(payload))

	s.mu.Lock()
	defer s.mu.Unlock()

	// length prefix and header, big-endian
	if err := binary.Write(s.writer, binary.BigEndian, p); err != nil {
		return err
	}

	// payload bytes go out verbatim
	if len(payload) > 0 {
		if _, err := s.writer.Write(payload); err != nil {
			return err
		}
	}

	return s.writer.Flush()
}
// SendStream reads stream until it is exhausted and sends its contents to
// libvirt as a sequence of Stream packets on the socket connection.
//
// Every chunk read is sent with StatusContinue. When stream reports io.EOF a
// final empty packet with StatusOK is sent and its send result returned. On
// any other read error an empty StatusError packet is sent and the read
// error is returned, unless sending that packet fails, in which case the
// send error is returned. A failed send of a data chunk ends the transfer
// with that error.
//
// Before each read, abort is polled without blocking: if a receive on it
// succeeds, an empty StatusError packet is sent and the result of that send
// is returned.
func (s *Socket) SendStream(serial int32, proc uint32, program uint32,
	stream io.Reader, abort chan bool) error {
	// Keep total packet length under 4 MiB to follow possible limitation in libvirt server code
	chunk := make([]byte, 4*MiB-unsafe.Sizeof(_p))

	// send emits one Stream packet of this transfer.
	send := func(payload []byte, status uint32) error {
		return s.SendPacket(serial, proc, program, payload, Stream, status)
	}

	for {
		select {
		case <-abort:
			return send(nil, StatusError)
		default:
		}

		n, readErr := stream.Read(chunk)

		// a Read may return data together with an error, so forward the
		// data first
		if n > 0 {
			if sendErr := send(chunk[:n], StatusContinue); sendErr != nil {
				return sendErr
			}
		}

		switch {
		case readErr == nil:
			continue
		case readErr == io.EOF:
			return send(nil, StatusOK)
		}

		// tell the server the stream is aborted, but keep the original
		// error when that notification goes through
		if sendErr := send(nil, StatusError); sendErr != nil {
			return sendErr
		}
		return readErr
	}
}
|