File: rsa.go

package info (click to toggle)
golang-github-lestrrat-go-jwx 2.1.4-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,872 kB
  • sloc: sh: 222; makefile: 86; perl: 62
file content (146 lines) | stat: -rw-r--r-- 3,250 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package jws

import (
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"fmt"

	"github.com/lestrrat-go/jwx/v2/internal/keyconv"
	"github.com/lestrrat-go/jwx/v2/jwa"
)

// Lookup tables holding one shared signer and one shared verifier per
// RSA-based signature algorithm, keyed by the algorithm identifier.
var (
	rsaSigners   map[jwa.SignatureAlgorithm]*rsaSigner
	rsaVerifiers map[jwa.SignatureAlgorithm]*rsaVerifier
)

// init populates rsaSigners and rsaVerifiers with one entry for each of
// the six RSA-based algorithms: RS256/RS384/RS512 (pss unset) and
// PS256/PS384/PS512 (pss set), paired with SHA-256, SHA-384 and SHA-512
// respectively.
func init() {
	rsaSigners = make(map[jwa.SignatureAlgorithm]*rsaSigner)
	rsaVerifiers = make(map[jwa.SignatureAlgorithm]*rsaVerifier)

	// register installs a signer/verifier pair that share the same settings.
	register := func(alg jwa.SignatureAlgorithm, hash crypto.Hash, pss bool) {
		rsaSigners[alg] = &rsaSigner{alg: alg, hash: hash, pss: pss}
		rsaVerifiers[alg] = &rsaVerifier{alg: alg, hash: hash, pss: pss}
	}

	register(jwa.RS256, crypto.SHA256, false)
	register(jwa.RS384, crypto.SHA384, false)
	register(jwa.RS512, crypto.SHA512, false)
	register(jwa.PS256, crypto.SHA256, true)
	register(jwa.PS384, crypto.SHA384, true)
	register(jwa.PS512, crypto.SHA512, true)
}

// rsaSigner creates RSA signatures for a single signature algorithm.
type rsaSigner struct {
	// alg is the algorithm identifier reported by Algorithm.
	alg  jwa.SignatureAlgorithm
	// hash is the hash function applied to the payload before signing.
	hash crypto.Hash
	// pss makes Sign pass rsa.PSSOptions to the key instead of the bare hash.
	pss  bool
}

// newRSASigner returns the shared Signer stored in rsaSigners for alg.
//
// It returns a nil Signer when no signer is stored for alg. The lookup
// result is checked explicitly because returning the missing map entry
// directly would wrap a nil *rsaSigner in a non-nil Signer interface,
// which defeats `== nil` checks in callers.
func newRSASigner(alg jwa.SignatureAlgorithm) Signer {
	if signer := rsaSigners[alg]; signer != nil {
		return signer
	}
	return nil
}

// Algorithm returns the signature algorithm identifier stored in the signer.
func (rs *rsaSigner) Algorithm() jwa.SignatureAlgorithm {
	return rs.alg
}

// Sign hashes payload with the signer's hash function and signs the
// resulting digest with key.
//
// key must be non-nil. If it implements crypto.Signer it is used as-is,
// provided isValidRSAKey accepts it; otherwise it is converted to an
// rsa.PrivateKey through keyconv.RSAPrivateKey. When the signer is
// configured for PSS, the salt length is set to the hash length.
//
// It returns the raw signature bytes, or an error if the key is missing
// or unusable, hashing fails, or the underlying Sign call fails.
func (rs *rsaSigner) Sign(payload []byte, key interface{}) ([]byte, error) {
	if key == nil {
		return nil, fmt.Errorf(`missing private key while signing payload`)
	}

	signer, isSigner := key.(crypto.Signer)
	switch {
	case !isSigner:
		// Not a crypto.Signer: fall back to converting the key material.
		var converted rsa.PrivateKey
		if err := keyconv.RSAPrivateKey(&converted, key); err != nil {
			return nil, fmt.Errorf(`failed to retrieve rsa.PrivateKey out of %T: %w`, key, err)
		}
		signer = &converted
	case !isValidRSAKey(key):
		return nil, fmt.Errorf(`cannot use key of type %T to generate RSA based signatures`, key)
	}

	hasher := rs.hash.New()
	if _, err := hasher.Write(payload); err != nil {
		return nil, fmt.Errorf(`failed to write payload to hash: %w`, err)
	}
	digest := hasher.Sum(nil)

	// Passing the bare hash selects PKCS #1 v1.5 for RSA keys; PSS needs
	// explicit options.
	var opts crypto.SignerOpts = rs.hash
	if rs.pss {
		opts = &rsa.PSSOptions{
			Hash:       rs.hash,
			SaltLength: rsa.PSSSaltLengthEqualsHash,
		}
	}
	return signer.Sign(rand.Reader, digest, opts)
}

// rsaVerifier checks RSA signatures for a single signature algorithm.
type rsaVerifier struct {
	// alg is the algorithm identifier this verifier is stored under.
	alg  jwa.SignatureAlgorithm
	// hash is the hash function applied to the payload before verification.
	hash crypto.Hash
	// pss makes Verify use rsa.VerifyPSS instead of rsa.VerifyPKCS1v15.
	pss  bool
}

// newRSAVerifier returns the shared Verifier stored in rsaVerifiers for alg.
//
// It returns a nil Verifier when no verifier is stored for alg. The lookup
// result is checked explicitly because returning the missing map entry
// directly would wrap a nil *rsaVerifier in a non-nil Verifier interface,
// which defeats `== nil` checks in callers.
func newRSAVerifier(alg jwa.SignatureAlgorithm) Verifier {
	if verifier := rsaVerifiers[alg]; verifier != nil {
		return verifier
	}
	return nil
}

// Verify hashes payload with the verifier's hash function and checks
// signature against the resulting digest using key.
//
// key must be non-nil. If it implements crypto.Signer, its Public() value
// must be an rsa.PublicKey or *rsa.PublicKey; otherwise key is converted
// to an rsa.PublicKey through keyconv.RSAPublicKey. PSS verification is
// performed with nil options, i.e. the crypto/rsa defaults.
//
// It returns nil when the signature is valid, and an error when the key
// is missing or unusable, hashing fails, or the signature does not verify.
func (rv *rsaVerifier) Verify(payload, signature []byte, key interface{}) error {
	if key == nil {
		return fmt.Errorf(`missing public key while verifying payload`)
	}

	var pubkey rsa.PublicKey
	cs, isSigner := key.(crypto.Signer)
	if !isSigner {
		if err := keyconv.RSAPublicKey(&pubkey, key); err != nil {
			return fmt.Errorf(`failed to retrieve rsa.PublicKey out of %T: %w`, key, err)
		}
	} else {
		// A private key (or other signer) was supplied: verify against its
		// public half.
		switch pub := cs.Public().(type) {
		case rsa.PublicKey:
			pubkey = pub
		case *rsa.PublicKey:
			pubkey = *pub
		default:
			return fmt.Errorf(`failed to retrieve rsa.PublicKey out of crypto.Signer %T`, key)
		}
	}

	hasher := rv.hash.New()
	if _, err := hasher.Write(payload); err != nil {
		return fmt.Errorf(`failed to write payload to hash: %w`, err)
	}
	digest := hasher.Sum(nil)

	if !rv.pss {
		return rsa.VerifyPKCS1v15(&pubkey, rv.hash, digest, signature)
	}
	return rsa.VerifyPSS(&pubkey, rv.hash, digest, signature, nil)
}