File: cipher.go

package info (click to toggle)
golang-github-lestrrat-go-jwx 2.1.4-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,872 kB
  • sloc: sh: 222; makefile: 86; perl: 62
file content (161 lines) | stat: -rw-r--r-- 3,659 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
package cipher

import (
	"crypto/aes"
	"crypto/cipher"
	"fmt"

	"github.com/lestrrat-go/jwx/v2/jwa"
	"github.com/lestrrat-go/jwx/v2/jwe/internal/aescbc"
	"github.com/lestrrat-go/jwx/v2/jwe/internal/keygen"
)

// Package-level fetcher instances that NewAES assigns as the `fetch`
// strategy of the ciphers it builds: gcm for the AES-GCM algorithms and
// cbc for the AES-CBC-HMAC algorithms.
var (
	gcm = new(gcmFetcher)
	cbc = new(cbcFetcher)
)

// Fetch builds an AES-GCM AEAD keyed with key. The AES variant is chosen by
// len(key); aes.NewCipher rejects any length other than 16, 24 or 32 bytes,
// and that error is returned wrapped.
//
// Both error messages carry the same "cipher:" prefix as the CBC fetcher so
// callers see consistent diagnostics.
func (f gcmFetcher) Fetch(key []byte) (cipher.AEAD, error) {
	aescipher, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf(`cipher: failed to create AES cipher for GCM: %w`, err)
	}

	aead, err := cipher.NewGCM(aescipher)
	if err != nil {
		return nil, fmt.Errorf(`cipher: failed to create GCM for cipher: %w`, err)
	}
	return aead, nil
}

// Fetch builds the AES-CBC + HMAC AEAD from the aescbc package for key,
// passing aes.NewCipher as the block-cipher constructor. Construction
// failures are wrapped with a "cipher:"-prefixed message.
func (cbcFetcher) Fetch(key []byte) (cipher.AEAD, error) {
	a, err := aescbc.New(key, aes.NewCipher)
	if err == nil {
		return a, nil
	}
	return nil, fmt.Errorf(`cipher: failed to create AES cipher for CBC: %w`, err)
}

// KeySize reports the content-encryption key length, in bytes, that NewAES
// configured for this cipher's algorithm.
func (c AesContentCipher) KeySize() int { return c.keysize }

// TagSize reports the authentication tag length, in bytes, that Encrypt
// splits off the end of the sealed output.
func (c AesContentCipher) TagSize() int { return c.tagsize }

// NewAES returns an AesContentCipher configured for alg.
//
//   - A128GCM / A192GCM / A256GCM: key of 16 / 24 / 32 bytes, 16-byte tag,
//     backed by the gcm fetcher.
//   - A128CBC_HS256 / A192CBC_HS384 / A256CBC_HS512: tag of 16 / 24 / 32
//     bytes and a key twice the tag length, backed by the cbc fetcher.
//
// Any other algorithm yields an error.
func NewAES(alg jwa.ContentEncryptionAlgorithm) (*AesContentCipher, error) {
	var (
		ks, ts int
		f      Fetcher
	)
	switch alg {
	case jwa.A128GCM:
		ks, ts, f = 16, 16, gcm
	case jwa.A192GCM:
		ks, ts, f = 24, 16, gcm
	case jwa.A256GCM:
		ks, ts, f = 32, 16, gcm
	case jwa.A128CBC_HS256:
		ts, f = 16, cbc
		ks = 2 * ts
	case jwa.A192CBC_HS384:
		ts, f = 24, cbc
		ks = 2 * ts
	case jwa.A256CBC_HS512:
		ts, f = 32, cbc
		ks = 2 * ts
	default:
		return nil, fmt.Errorf("failed to create AES content cipher: invalid algorithm (%s)", alg)
	}

	return &AesContentCipher{keysize: ks, tagsize: ts, fetch: f}, nil
}

// Encrypt seals plaintext with the content-encryption key cek, authenticating
// aad. It returns the nonce that was used (iv), the ciphertext, and the
// authentication tag as separate slices.
//
// The nonce comes from c.NonceGenerator when it is set. Otherwise it is
// generated by keygen.NewRandom(aead.NonceSize()). A nonce whose length does
// not match aead.NonceSize() is rejected with an error rather than passed to
// Seal; the stdlib GCM implementation panics on a wrong-sized nonce.
//
// On any error, iv, ciphertxt and tag are all nil.
func (c AesContentCipher) Encrypt(cek, plaintext, aad []byte) (iv, ciphertxt, tag []byte, err error) {
	var aead cipher.AEAD
	aead, err = c.fetch.Fetch(cek)
	if err != nil {
		return nil, nil, nil, fmt.Errorf(`failed to fetch AEAD: %w`, err)
	}

	// Seal may still panic inside the AEAD implementation. Convert that into
	// an error, and make sure no partially-populated results escape with it.
	defer func() {
		if e := recover(); e != nil {
			perr, ok := e.(error)
			if !ok {
				perr = fmt.Errorf("%v", e)
			}
			iv, ciphertxt, tag = nil, nil, nil
			err = fmt.Errorf(`failed to encrypt: %w`, perr)
		}
	}()

	var bs keygen.ByteSource
	if c.NonceGenerator == nil {
		bs, err = keygen.NewRandom(aead.NonceSize()).Generate()
	} else {
		bs, err = c.NonceGenerator.Generate()
	}
	if err != nil {
		return nil, nil, nil, fmt.Errorf(`failed to generate nonce: %w`, err)
	}
	nonce := bs.Bytes()
	if len(nonce) != aead.NonceSize() {
		return nil, nil, nil, fmt.Errorf(`failed to encrypt: invalid nonce size (got %d, expected %d)`, len(nonce), aead.NonceSize())
	}

	// Seal returns ciphertext||tag; split the trailing TagSize() bytes off.
	combined := aead.Seal(nil, nonce, plaintext, aad)
	tagoffset := len(combined) - c.TagSize()
	if tagoffset < 0 {
		return nil, nil, nil, fmt.Errorf(`failed to encrypt: tag offset is less than 0 (combined len = %d, tagsize = %d)`, len(combined), c.TagSize())
	}

	ciphertxt = make([]byte, tagoffset)
	copy(ciphertxt, combined[:tagoffset])
	return nonce, ciphertxt, combined[tagoffset:], nil
}

// Decrypt authenticates and decrypts ciphertxt with the content-encryption
// key cek, nonce iv, authentication tag, and additional authenticated data
// aad. It returns the recovered plaintext.
//
// iv is taken from the incoming message, so its length is checked against
// aead.NonceSize() before use. The stdlib GCM Open panics, rather than
// returning an error, when given a wrong-sized nonce.
//
// On any error the returned plaintext is nil.
func (c AesContentCipher) Decrypt(cek, iv, ciphertxt, tag, aad []byte) (plaintext []byte, err error) {
	aead, err := c.fetch.Fetch(cek)
	if err != nil {
		return nil, fmt.Errorf(`failed to fetch AEAD data: %w`, err)
	}

	if len(iv) != aead.NonceSize() {
		return nil, fmt.Errorf(`failed to decrypt: invalid nonce size (got %d, expected %d)`, len(iv), aead.NonceSize())
	}

	// Open may still panic inside the AEAD implementation. Turn that into an
	// error instead of crashing the caller.
	defer func() {
		if e := recover(); e != nil {
			perr, ok := e.(error)
			if !ok {
				perr = fmt.Errorf(`%v`, e)
			}
			plaintext = nil
			err = fmt.Errorf(`failed to decrypt: %w`, perr)
		}
	}()

	// AEAD.Open expects the tag appended to the ciphertext.
	combined := make([]byte, 0, len(ciphertxt)+len(tag))
	combined = append(combined, ciphertxt...)
	combined = append(combined, tag...)

	buf, aeaderr := aead.Open(nil, iv, combined, aad)
	if aeaderr != nil {
		return nil, fmt.Errorf(`aead.Open failed: %w`, aeaderr)
	}
	return buf, nil
}