File: jwe_example_test.go

package info (click to toggle)
golang-github-lestrrat-go-jwx 2.1.4-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,872 kB
  • sloc: sh: 222; makefile: 86; perl: 62
file content (120 lines) | stat: -rw-r--r-- 3,498 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package examples_test

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"fmt"
	"log"

	"github.com/lestrrat-go/jwx/v2/internal/jwxtest"
	"github.com/lestrrat-go/jwx/v2/jwa"
	"github.com/lestrrat-go/jwx/v2/jwe"
)

// exampleGenPayload creates a fresh 2048-bit RSA key pair and uses its public
// half to encrypt the fixed payload "Lorem Ipsum" as a JWE message, using
// RSA1_5 key encryption and A128CBC_HS256 content encryption.
//
// It returns the private key (needed to decrypt) together with the encrypted
// message bytes. On any failure both values are nil and the error is returned
// unchanged.
func exampleGenPayload() (*rsa.PrivateKey, []byte, error) {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, err
	}

	// RSA1_5 is used here only because the decrypting example expects it;
	// it is generally discouraged for new designs (prefer RSA-OAEP).
	opts := []jwe.EncryptOption{
		jwe.WithKey(jwa.RSA1_5, &key.PublicKey),
		jwe.WithContentEncryption(jwa.A128CBC_HS256),
	}

	msg, err := jwe.Encrypt([]byte("Lorem Ipsum"), opts...)
	if err != nil {
		return nil, nil, err
	}
	return key, msg, nil
}

// Example_jwe_decrypt demonstrates decrypting a JWE message with the RSA
// private key matching the public key it was encrypted for.
//
// Diagnostics are written with log.Printf, which goes to stderr. `go test`
// compares only stdout against the expected output below. So, to make the
// example actually fail when something goes wrong, the decrypted payload is
// printed to stdout only on success and checked against the expected output.
func Example_jwe_decrypt() {
	privkey, encrypted, err := exampleGenPayload()
	if err != nil {
		log.Printf("failed to generate encrypted payload: %s", err)
		return
	}

	decrypted, err := jwe.Decrypt(encrypted, jwe.WithKey(jwa.RSA1_5, privkey))
	if err != nil {
		log.Printf("failed to decrypt: %s", err)
		return
	}

	if string(decrypted) != "Lorem Ipsum" {
		log.Printf("WHAT?!")
		return
	}

	// Reaching this line is the only way the expected output is produced.
	fmt.Printf("%s\n", decrypted)
	// OUTPUT:
	// Lorem Ipsum
}

// Example_jwe_complex_decrypt shows how a jwe.KeyProvider can choose the
// decryption key at Decrypt time by inspecting the parsed message. Here it
// looks at the value of a custom protected header named `jwx-hints`.
//
// Failures are reported with fmt.Printf (stdout). Any error therefore makes the
// output differ from the empty expected output, and the example fails.
func Example_jwe_complex_decrypt() {
	// WARNING: THIS USAGE IS NOT FOR A CASUAL USER. ONLY use it when you must.
	// Only use it when you understand how JWE is supposed to work. Only use it
	// when you understand the inner workings of this code.

	// In this example, the caller wants to determine the key to use by checking
	// the value of a protected header called `jwx-hints`.

	const payload = "Hello, World!"

	privkey, err := jwxtest.GenerateRsaKey()
	if err != nil {
		fmt.Printf("failed to generate key: %s\n", err)
		return
	}

	// First we create a sample JWE message whose protected headers carry the hint.
	protected := jwe.NewHeaders()
	// In real life this would be a more meaningful value.
	if err := protected.Set(`jwx-hints`, `foobar`); err != nil {
		fmt.Printf("failed to set protected header: %s\n", err)
		return
	}
	encrypted, err := jwe.Encrypt(
		[]byte(payload),
		jwe.WithKey(jwa.RSA_OAEP, privkey.PublicKey),
		jwe.WithProtectedHeaders(protected),
	)
	if err != nil {
		fmt.Printf("failed to encrypt message: %s\n", err)
		return
	}

	// The party responsible for determining the key is the jwe.KeyProvider hook.
	//
	// Here we are using a function turned into an interface for brevity, but in real life
	// I would personally recommend creating a real type for your specific needs
	// instead of passing adhoc closures. YMMV.
	kp := func(_ context.Context, sink jwe.KeySink, _ jwe.Recipient, msg *jwe.Message) error {
		rawhint, _ := msg.ProtectedHeaders().Get(`jwx-hints`)
		// Comma-ok assertion: a missing or non-string hint falls through to
		// the error below instead of panicking.
		if hint, ok := rawhint.(string); ok && hint == `foobar` {
			// This is where we are setting the key to be used.
			//
			// In real life you would look up the key or something.
			// Here we just assign the key to use.
			//
			// You may opt to set both the algorithm and key here as well.
			// BUT BE CAREFUL so that you don't accidentally create a
			// vulnerability
			sink.Key(jwa.RSA_OAEP, privkey)
			return nil
		}

		// Returning an error here makes the whole jwe.Decrypt call fail.
		// %v rather than %s: rawhint is an arbitrary value and may be nil.
		return fmt.Errorf(`invalid value for jwx-hints: %v`, rawhint)
	}

	// No key is passed to jwe.Decrypt directly. The key provider registered
	// via jwe.WithKeyProvider supplies it (through sink.Key) when it is invoked.
	decrypted, err := jwe.Decrypt(encrypted, jwe.WithKeyProvider(jwe.KeyProviderFunc(kp)))
	if err != nil {
		fmt.Printf("failed to decrypt message: %s\n", err)
		return
	}

	if string(decrypted) != payload {
		fmt.Printf("wrong decrypted payload: %s\n", decrypted)
		return
	}

	// OUTPUT:
}