1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
|
package cipher
import (
"crypto/aes"
"crypto/cipher"
"fmt"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwe/internal/aescbc"
"github.com/lestrrat-go/jwx/v2/jwe/internal/keygen"
)
// Package-level fetcher instances. NewAES stores one of these in the
// fetch field of every AesContentCipher it builds: gcm for the *GCM
// algorithms and cbc for the *CBC_HS* algorithms.
var (
	gcm = &gcmFetcher{}
	cbc = &cbcFetcher{}
)
// Fetch builds an AES-GCM AEAD for key.
//
// key is handed to aes.NewCipher unchanged, and the resulting block
// cipher is wrapped with cipher.NewGCM. A failure in either step is
// returned wrapped, together with a nil AEAD.
func (f gcmFetcher) Fetch(key []byte) (cipher.AEAD, error) {
	block, blockErr := aes.NewCipher(key)
	if blockErr != nil {
		return nil, fmt.Errorf(`cipher: failed to create AES cipher for GCM: %w`, blockErr)
	}

	sealer, gcmErr := cipher.NewGCM(block)
	if gcmErr != nil {
		return nil, fmt.Errorf(`failed to create GCM for cipher: %w`, gcmErr)
	}

	return sealer, nil
}
// Fetch builds the AEAD used for the CBC-based algorithms by calling
// aescbc.New with key and aes.NewCipher as the block-cipher constructor.
// If aescbc.New fails, the error is returned wrapped and the AEAD is nil.
func (f cbcFetcher) Fetch(key []byte) (cipher.AEAD, error) {
	sealer, newErr := aescbc.New(key, aes.NewCipher)
	if newErr != nil {
		return nil, fmt.Errorf(`cipher: failed to create AES cipher for CBC: %w`, newErr)
	}

	return sealer, nil
}
// KeySize returns the key size stored in the cipher. NewAES sets it per
// algorithm (for example 16 for A128GCM and 32 for A128CBC_HS256).
func (c AesContentCipher) KeySize() int {
	size := c.keysize
	return size
}
// TagSize returns the tag size stored in the cipher. Encrypt uses it to
// split the trailing tag off the sealed output.
func (c AesContentCipher) TagSize() int {
	size := c.tagsize
	return size
}
// NewAES returns an AesContentCipher configured for alg.
//
// Sizes are selected as follows:
//
//	A128GCM / A192GCM / A256GCM: key 16 / 24 / 32, tag 16, gcm fetcher
//	A128CBC_HS256:               key 32, tag 16, cbc fetcher
//	A192CBC_HS384:               key 48, tag 24, cbc fetcher
//	A256CBC_HS512:               key 64, tag 32, cbc fetcher
//
// For the CBC variants the key size is always twice the tag size. Any
// other algorithm yields a nil cipher and an error.
func NewAES(alg jwa.ContentEncryptionAlgorithm) (*AesContentCipher, error) {
	// assemble builds the result once the per-algorithm values are known.
	assemble := func(keysize, tagsize int, fetcher Fetcher) (*AesContentCipher, error) {
		return &AesContentCipher{
			keysize: keysize,
			tagsize: tagsize,
			fetch:   fetcher,
		}, nil
	}

	switch alg {
	case jwa.A128GCM:
		return assemble(16, 16, gcm)
	case jwa.A192GCM:
		return assemble(24, 16, gcm)
	case jwa.A256GCM:
		return assemble(32, 16, gcm)
	case jwa.A128CBC_HS256:
		return assemble(2*16, 16, cbc)
	case jwa.A192CBC_HS384:
		return assemble(2*24, 24, cbc)
	case jwa.A256CBC_HS512:
		return assemble(2*32, 32, cbc)
	}

	return nil, fmt.Errorf("failed to create AES content cipher: invalid algorithm (%s)", alg)
}
// Encrypt seals plaintext with the AEAD that c.fetch builds from cek and
// returns the nonce that was used (iv), the ciphertext, and the tag as
// three separate slices.
//
// The nonce comes from c.NonceGenerator when it is set; otherwise one of
// aead.NonceSize() length is produced via keygen.NewRandom. aad is passed
// to Seal as the additional data. Seal's output is split at
// len(output)-c.TagSize(): the leading part is copied into ciphertxt and
// the trailing c.TagSize() bytes are returned as tag.
//
// On any failure, including a panic raised inside Seal, all three slices
// are nil and err describes the problem.
func (c AesContentCipher) Encrypt(cek, plaintext, aad []byte) (iv, ciphertxt, tag []byte, err error) {
	var aead cipher.AEAD
	aead, err = c.fetch.Fetch(cek)
	if err != nil {
		return nil, nil, nil, fmt.Errorf(`failed to fetch AEAD: %w`, err)
	}

	// Seal may panic (argh!), so protect ourselves from that
	defer func() {
		e := recover()
		if e == nil {
			return
		}
		switch e := e.(type) {
		case error:
			err = e
		default:
			// A panic value can be of any type; %v renders all of them,
			// whereas %s produces "%!s(...)" noise for non-strings.
			err = fmt.Errorf("%v", e)
		}
		// Never hand back partial results together with an error.
		iv, ciphertxt, tag = nil, nil, nil
		err = fmt.Errorf(`failed to encrypt: %w`, err)
	}()

	var bs keygen.ByteSource
	if c.NonceGenerator == nil {
		bs, err = keygen.NewRandom(aead.NonceSize()).Generate()
	} else {
		bs, err = c.NonceGenerator.Generate()
	}
	if err != nil {
		return nil, nil, nil, fmt.Errorf(`failed to generate nonce: %w`, err)
	}
	nonce := bs.Bytes()

	combined := aead.Seal(nil, nonce, plaintext, aad)
	tagoffset := len(combined) - c.TagSize()
	if tagoffset < 0 {
		// Report the broken invariant directly instead of panicking and
		// relying on the recover above; the message is unchanged.
		return nil, nil, nil, fmt.Errorf(`failed to encrypt: tag offset is less than 0 (combined len = %d, tagsize = %d)`, len(combined), c.TagSize())
	}

	// The ciphertext is copied out so it does not share storage with tag.
	body := make([]byte, tagoffset)
	copy(body, combined[:tagoffset])
	return nonce, body, combined[tagoffset:], nil
}
// Decrypt opens ciphertxt with the AEAD that c.fetch builds from cek and
// returns the recovered plaintext.
//
// ciphertxt and tag are concatenated (ciphertext first, tag last) into
// the single buffer passed to aead.Open; iv is used as the nonce and aad
// as the additional data.
//
// On failure plaintext is nil. A panic raised inside Open is recovered
// and reported as an error prefixed with "failed to decrypt: "; if the
// panic value is itself an error it stays reachable through unwrapping.
func (c AesContentCipher) Decrypt(cek, iv, ciphertxt, tag, aad []byte) (plaintext []byte, err error) {
	aead, err := c.fetch.Fetch(cek)
	if err != nil {
		return nil, fmt.Errorf(`failed to fetch AEAD data: %w`, err)
	}

	// Open may panic (argh!), so protect ourselves from that
	defer func() {
		e := recover()
		if e == nil {
			return
		}
		switch e := e.(type) {
		case error:
			err = e
		default:
			// A panic value can be of any type; %v renders all of them,
			// whereas %s produces "%!s(...)" noise for non-strings.
			err = fmt.Errorf(`%v`, e)
		}
		plaintext = nil
		err = fmt.Errorf(`failed to decrypt: %w`, err)
	}()

	// Open expects the tag appended to the ciphertext.
	combined := make([]byte, 0, len(ciphertxt)+len(tag))
	combined = append(combined, ciphertxt...)
	combined = append(combined, tag...)

	buf, aeaderr := aead.Open(nil, iv, combined, aad)
	if aeaderr != nil {
		return nil, fmt.Errorf(`aead.Open failed: %w`, aeaderr)
	}
	return buf, nil
}
|