1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
|
package tpm
import (
"context"
"errors"
"fmt"
"io"
"math"
"github.com/google/go-tpm/legacy/tpm2"
)
// ShortRandomReadError reports that the TPM did not return the number of
// random bytes that was requested.
type ShortRandomReadError struct {
	Requested int // number of random bytes asked for
	Generated int // number of random bytes the TPM actually returned
}

// Error implements the error interface.
func (e ShortRandomReadError) Error() string {
	const format = "generated %d random bytes instead of the requested %d"
	return fmt.Sprintf(format, e.Generated, e.Requested)
}
// GenerateRandom returns size random bytes generated by the TPM.
//
// The TPM is opened for the duration of the call and closed via a deferred
// closeTPM. closeTPM receives a pointer to the named err result, presumably so
// that it can report a failure to close; confirm against its definition.
// If fewer bytes than requested are produced, a ShortRandomReadError is
// returned.
func (t *TPM) GenerateRandom(ctx context.Context, size uint16) (random []byte, err error) {
	err = t.open(goTPMCall(ctx))
	if err != nil {
		return nil, fmt.Errorf("failed opening TPM: %w", err)
	}
	defer closeTPM(ctx, t, &err)

	random, err = t.generateRandom(ctx, size)
	return random, err
}
// generateRandom asks the already-open TPM (callers in this file open it
// first) for size random bytes using tpm2.GetRandom.
//
// The context is currently unused. If the TPM returns a different number of
// bytes than requested, the bytes are discarded and a ShortRandomReadError
// describing both counts is returned instead.
func (t *TPM) generateRandom(_ context.Context, size uint16) (random []byte, err error) {
	got, err := tpm2.GetRandom(t.rwc, size)
	switch {
	case err != nil:
		return nil, fmt.Errorf("failed generating random data: %w", err)
	case len(got) != int(size):
		return nil, ShortRandomReadError{Requested: int(size), Generated: len(got)}
	}
	return got, nil
}
// generator is an io.Reader whose bytes come from a TPM's random number
// generator.
type generator struct {
	t *TPM
	// readError is the error that ended random generation in an earlier call
	// to Read; once it is set, every later call to Read fails.
	readError error
}

// RandomReader returns an io.Reader that produces random bytes generated by
// the TPM. The returned error is always nil.
func (t *TPM) RandomReader() (io.Reader, error) {
	g := &generator{t: t}
	return g, nil
}
// Read fills p with random bytes generated by the TPM. It implements io.Reader.
//
// At most math.MaxUint16 bytes are produced per call, because a single TPM
// request size is a uint16. The TPM is opened for the duration of the call and
// closed by a deferred closeTPM, which receives a pointer to the named err result.
//
// If the TPM returns fewer bytes than requested, the request size is lowered to
// the number of bytes it did return, and the missing bytes are requested again.
// Requests never exceed the number of bytes still missing from p.
//
// Any other failure is sticky. It is stored, and every later call to Read fails
// with an error wrapping io.EOF. If some bytes were already collected in this
// call, they are returned with a nil error. If none were, the failure is
// returned immediately instead of the (0, nil) that io.Reader discourages.
func (g *generator) Read(p []byte) (n int, err error) {
	if g.readError != nil {
		errMsg := g.readError.Error() // multiple wrapped errors not (yet) allowed
		return 0, fmt.Errorf("failed generating random bytes in previous call to Read: %s: %w", errMsg, io.EOF)
	}
	if len(p) == 0 {
		// Nothing to do; avoid opening the TPM at all.
		return 0, nil
	}
	if len(p) > math.MaxUint16 {
		p = p[:math.MaxUint16]
	}
	ctx := context.Background()
	if err = g.t.open(goTPMCall(ctx)); err != nil {
		return 0, fmt.Errorf("failed opening TPM: %w", err)
	}
	defer closeTPM(ctx, g.t, &err)

	// maxRequest is the largest request size the TPM has been able to satisfy
	// so far; it shrinks whenever the TPM returns a short read.
	maxRequest := uint16(len(p))
	for n < len(p) {
		// Never ask for more than is still missing, so no generated bytes are wasted.
		request := maxRequest
		if remaining := uint16(len(p) - n); remaining < request {
			request = remaining
		}
		r, rerr := g.t.generateRandom(ctx, request)
		if rerr == nil {
			n += copy(p[n:], r)
			continue
		}
		var s ShortRandomReadError
		if errors.As(rerr, &s) && s.Generated > 0 {
			// generateRandom discards the bytes of a short read, so nothing is
			// collected here; just lower the request size to what the TPM
			// delivered and try again.
			maxRequest = uint16(s.Generated)
			continue
		}
		g.readError = rerr // store the error to be returned for future calls to Read
		if n == 0 {
			return 0, rerr
		}
		return n, nil // return the bytes collected so far and no error
	}
	return n, nil
}

var _ io.Reader = (*generator)(nil)
|