1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
|
// Copyright 2016-2018 Yubico AB
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build windows
// +build windows
package main
import (
"fmt"
"sync"
"unsafe"
log "github.com/sirupsen/logrus"
)
// #cgo CFLAGS: -DUNICODE -D_UNICODE
// #cgo LDFLAGS: -lwinusb -lsetupapi -luuid
// #include "usb_windows.h"
import "C"
// device is the single package-wide USB device state. usbCheck and usbProxy
// hold mtx for their whole duration while they open, use and reopen ctx.
var device struct {
	mtx sync.Mutex        // serialises access to ctx
	ctx C.PDEVICE_CONTEXT // nil until usbopen obtains a context
}
// C_DWORD wraps a DWORD status code returned by the C helpers so that it can
// be used directly as a Go error value.
type C_DWORD C.DWORD

// Error formats the status code in hexadecimal, prefixed with "Windows Error: 0x".
func (code C_DWORD) Error() string {
	value := uint(code)
	return fmt.Sprintf("Windows Error: 0x%x", value)
}
// Windows status codes re-exported as C_DWORD values so they can be returned
// and compared as Go errors alongside the codes produced by winusbError.
const (
	SUCCESS                 C_DWORD = C.ERROR_SUCCESS
	ERROR_INVALID_STATE     C_DWORD = C.ERROR_INVALID_STATE
	ERROR_INVALID_HANDLE    C_DWORD = C.ERROR_INVALID_HANDLE
	ERROR_INVALID_PARAMETER C_DWORD = C.ERROR_INVALID_PARAMETER
	ERROR_OUTOFMEMORY       C_DWORD = C.ERROR_OUTOFMEMORY
	ERROR_GEN_FAILURE       C_DWORD = C.ERROR_GEN_FAILURE
	ERROR_OBJECT_NOT_FOUND  C_DWORD = C.ERROR_OBJECT_NOT_FOUND
	ERROR_NOT_SUPPORTED     C_DWORD = C.ERROR_NOT_SUPPORTED
	ERROR_SHARING_VIOLATION C_DWORD = C.ERROR_SHARING_VIOLATION
	ERROR_BAD_COMMAND       C_DWORD = C.ERROR_BAD_COMMAND
)
// winusbError turns a DWORD status from the C helpers into a Go error.
// ERROR_SUCCESS maps to a nil error; any other code is returned as a C_DWORD.
func winusbError(status C.DWORD) error {
	if status == C.ERROR_SUCCESS {
		return nil
	}
	return C_DWORD(status)
}
// usbopen opens the package-wide device context via C.usbOpen with USB IDs
// 0x1050/0x0030 (presumably vendor/product ID — confirm against
// usb_windows.h). It is a no-op if device.ctx is already set.
//
// When serial is non-empty it is passed to C.usbOpen as a C string; otherwise
// nil is passed (presumably meaning "any matching device" — TODO confirm).
//
// Callers must hold device.mtx. If no context was obtained, the returned
// error says "device not found" and, when C.usbOpen reported a status code,
// wraps that code so it remains visible in logs and to errors.As.
func usbopen(cid string, serial string) (err error) {
	if device.ctx != nil {
		log.WithField("Correlation-ID", cid).Debug("usb context already open")
		return nil
	}

	var cSerial *C.char
	if serial != "" {
		cSerial = C.CString(serial)
		// C.CString allocates with malloc; release it once usbOpen is done.
		defer C.free(unsafe.Pointer(cSerial))
	}
	err = winusbError(C.usbOpen(0x1050, 0x0030, cSerial, &device.ctx))

	if device.ctx == nil {
		// Keep the underlying Windows status instead of discarding it.
		if err != nil {
			return fmt.Errorf("device not found: %w", err)
		}
		return fmt.Errorf("device not found")
	}
	return err
}
// usbclose hands the current device context to C.usbClose, if one is open.
// The address of device.ctx is passed, so the C side may update it.
// Callers must hold device.mtx. cid is accepted for signature symmetry with
// the other helpers and is currently unused.
func usbclose(cid string) {
	if device.ctx == nil {
		return
	}
	C.usbClose(&device.ctx)
}
// usbreopen logs why a reopen is needed, closes the current device context
// and opens a fresh one for the given serial. It returns usbopen's result.
// Callers must hold device.mtx.
func usbreopen(cid string, why error, serial string) (err error) {
	entry := log.WithField("Correlation-ID", cid).WithField("why", why)
	entry.Debug("reopening usb context")

	usbclose(cid)
	err = usbopen(cid, serial)
	return err
}
// usbCheck makes sure a device context is open and that C.usbCheck accepts it.
// Each time the check fails, the context is reopened and checked again.
//
// The number of reopen attempts is bounded. Previously the loop was unbounded,
// so a device that reopened successfully but kept failing the check would
// spin forever while holding device.mtx, stalling every other USB request.
// It returns nil once a check succeeds. Otherwise it returns the last check
// error, or the first open/reopen error.
func usbCheck(cid string, serial string) (err error) {
	device.mtx.Lock()
	defer device.mtx.Unlock()

	if err = usbopen(cid, serial); err != nil {
		return err
	}

	const maxReopens = 2
	for reopens := 0; ; reopens++ {
		err = winusbError(C.usbCheck(device.ctx, 0x1050, 0x0030))
		if err == nil {
			return nil
		}
		log.WithFields(log.Fields{
			"Correlation-ID": cid,
			"Error":          err,
		}).Debug("Couldn't check usb context")

		if reopens == maxReopens {
			return err
		}
		if err = usbreopen(cid, err, serial); err != nil {
			return err
		}
	}
}
// usbwrite sends buf through C.usbWrite on the current device context and
// logs the outcome (bytes written, error and payload) at debug level.
//
// An empty buf is rejected with ERROR_INVALID_PARAMETER. C.usbWrite needs a
// pointer to the first byte, and &buf[0] would panic on an empty slice.
// Callers must hold device.mtx.
func usbwrite(buf []byte, cid string) (err error) {
	var n C.ULONG
	if len(buf) == 0 {
		err = ERROR_INVALID_PARAMETER
	} else {
		err = winusbError(C.usbWrite(
			device.ctx,
			(*C.UCHAR)(unsafe.Pointer(&buf[0])),
			C.ULONG(len(buf)),
			&n))
	}

	log.WithFields(log.Fields{
		"Correlation-ID": cid,
		"n":              uint(n),
		"err":            err,
		"len":            len(buf),
		"buf":            buf,
	}).Debug("usb endpoint write")
	return err
}
// usbread reads one response of at most 8192 bytes via C.usbRead from the
// current device context. On success the returned slice is trimmed to the
// byte count reported by C. On failure an empty, non-nil slice is returned
// together with the error. The outcome is logged at debug level.
// Callers must hold device.mtx.
func usbread(cid string) (buf []byte, err error) {
	const capacity = 8192
	var got C.ULONG

	buf = make([]byte, capacity)
	status := C.usbRead(
		device.ctx,
		(*C.UCHAR)(unsafe.Pointer(&buf[0])),
		C.ULONG(len(buf)),
		&got)

	if err = winusbError(status); err == nil {
		buf = buf[:got]
	} else {
		buf = buf[:0]
	}

	log.WithFields(log.Fields{
		"Correlation-ID": cid,
		"n":              uint(got),
		"err":            err,
		"len":            len(buf),
		"buf":            buf,
	}).Debug("usb endpoint read")
	return buf, err
}
func usbProxy(req []byte, cid string, serial string) (resp []byte, err error) {
device.mtx.Lock()
defer device.mtx.Unlock()
if err = usbopen(cid, serial); err != nil {
return nil, err
}
for i := 0; i < 2; i++ {
if err = usbwrite(req, cid); err != nil {
if err2 := usbreopen(cid, err, serial); err2 != nil {
return nil, err2
}
continue
}
resp, err = usbread(cid)
break
}
return resp, err
}
|