File: main.go

package info (click to toggle)
golang-github-kisom-goutils 0.0~git20161101.0.858c9cb-1
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 384 kB
  • ctags: 331
  • sloc: makefile: 6
file content (161 lines) | stat: -rw-r--r-- 3,331 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
package main

import (
	"bytes"
	"crypto"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"os"

	"github.com/kisom/goutils/die"
)

// validPEMs is the set of PEM block types that loadKey accepts for a
// private key file. A PEM block whose type is not listed here is rejected
// before any key parsing is attempted.
var validPEMs = map[string]bool{
	"PRIVATE KEY":     true,
	"RSA PRIVATE KEY": true,
	"EC PRIVATE KEY":  true,
}

// Key classifications returned by getECCurve. The zero value is
// curveInvalid, so an unset classification never looks like a usable curve.
const (
	curveInvalid = iota // any invalid curve (also returned for unrecognised key types)
	curveRSA            // indicates key is an RSA key, not an EC key
	curveP256           // EC key whose curve is elliptic.P256()
	curveP384           // EC key whose curve is elliptic.P384()
	curveP521           // EC key whose curve is elliptic.P521()
)

func getECCurve(pub interface{}) int {
	switch pub := pub.(type) {
	case *ecdsa.PublicKey:
		switch pub.Curve {
		case elliptic.P256():
			return curveP256
		case elliptic.P384():
			return curveP384
		case elliptic.P521():
			return curveP521
		default:
			return curveInvalid
		}
	case *rsa.PublicKey:
		return curveRSA
	default:
		return curveInvalid
	}
}

// loadKey reads the private key stored in the file at path and returns it
// as a crypto.Signer.
//
// The file may contain a PEM block or raw DER. Surrounding whitespace is
// ignored when probing for PEM; a PEM block must have one of the types in
// validPEMs. The key bytes are then parsed as PKCS #8, PKCS #1 and SEC 1
// (in that order). Only RSA and ECDSA keys are returned; any other key
// type, an unreadable file, or unparseable data yields a non-nil error.
func loadKey(path string) (crypto.Signer, error) {
	in, err := ioutil.ReadFile(path)
	if err != nil {
		return nil, err
	}

	trimmed := bytes.TrimSpace(in)
	der, fromPEM := trimmed, false
	if p, _ := pem.Decode(trimmed); p != nil {
		if !validPEMs[p.Type] {
			return nil, errors.New("invalid private key file type " + p.Type)
		}
		der, fromPEM = p.Bytes, true
	}

	priv, err := parsePrivateKeyDER(der)
	if err != nil && !fromPEM && len(trimmed) != len(in) {
		// Raw DER is binary, so its first or last byte can legitimately
		// be a whitespace value (e.g. 0x0a or 0x20). Trimming then
		// truncates the encoding; retry with the untouched file contents
		// before giving up.
		if rawKey, rawErr := parsePrivateKeyDER(in); rawErr == nil {
			priv, err = rawKey, nil
		}
	}
	if err != nil {
		return nil, err
	}

	switch key := priv.(type) {
	case *rsa.PrivateKey:
		return key, nil
	case *ecdsa.PrivateKey:
		return key, nil
	}

	// Reached when the data parsed successfully (via PKCS #8) but holds a
	// key type other than RSA or ECDSA.
	return nil, errors.New("invalid private key")
}

// parsePrivateKeyDER decodes der as a PKCS #8, PKCS #1 or SEC 1 private
// key, returning the result of the first format that parses. If none
// does, the error from the SEC 1 attempt is returned.
func parsePrivateKeyDER(der []byte) (interface{}, error) {
	if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
		return key, nil
	}
	if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
		return key, nil
	}
	key, err := x509.ParseECPrivateKey(der)
	if err != nil {
		// Return an untyped nil so callers never see a non-nil interface
		// wrapping a nil *ecdsa.PrivateKey.
		return nil, err
	}
	return key, nil
}

// main reports whether a private key and a certificate belong together.
//
// The private key file is given with -k and the certificate file with -c.
// The certificate may be PEM (block type CERTIFICATE) or raw DER. Read and
// parse failures are handed to the die package. When the key's public half
// equals the certificate's public key, "Match." is printed; otherwise a
// "No match" reason is printed and the process exits with status 1.
func main() {
	var keyFile, certFile string
	flag.StringVar(&keyFile, "k", "", "TLS private `key` file")
	flag.StringVar(&certFile, "c", "", "TLS `certificate` file")
	flag.Parse()

	in, err := ioutil.ReadFile(certFile)
	die.If(err)

	// Input that does not decode as PEM is passed to the certificate
	// parser unchanged, i.e. it is treated as DER.
	if p, _ := pem.Decode(in); p != nil {
		if p.Type != "CERTIFICATE" {
			die.With("invalid certificate (type is %s)", p.Type)
		}
		in = p.Bytes
	}
	cert, err := x509.ParseCertificate(in)
	die.If(err)

	priv, err := loadKey(keyFile)
	die.If(err)

	switch pub := priv.Public().(type) {
	case *rsa.PublicKey:
		switch certPub := cert.PublicKey.(type) {
		case *rsa.PublicKey:
			if pub.N.Cmp(certPub.N) != 0 || pub.E != certPub.E {
				fmt.Println("No match (public keys don't match).")
				os.Exit(1)
			}
			fmt.Println("Match.")
			return
		case *ecdsa.PublicKey:
			fmt.Println("No match (RSA private key, EC public key).")
			os.Exit(1)
		default:
			// A certificate carrying any other key type cannot match an
			// RSA key. Report it explicitly; without this branch the
			// program would exit successfully without printing anything.
			fmt.Printf("No match (RSA private key, %T public key).\n", certPub)
			os.Exit(1)
		}
	case *ecdsa.PublicKey:
		privCurve := getECCurve(pub)
		certCurve := getECCurve(cert.PublicKey)
		log.Printf("priv: %d\tcert: %d\n", privCurve, certCurve)

		switch {
		case certCurve == curveRSA:
			fmt.Println("No match (private key is EC, certificate is RSA).")
			os.Exit(1)
		case privCurve == curveInvalid:
			fmt.Println("No match (invalid private key curve).")
			os.Exit(1)
		case privCurve != certCurve:
			fmt.Println("No match (EC curves don't match).")
			os.Exit(1)
		}

		// The checks above leave certCurve equal to a valid EC curve
		// constant, which getECCurve only returns for *ecdsa.PublicKey, so
		// this assertion cannot panic.
		certPub := cert.PublicKey.(*ecdsa.PublicKey)
		if pub.X.Cmp(certPub.X) != 0 || pub.Y.Cmp(certPub.Y) != 0 {
			fmt.Println("No match (public keys don't match).")
			os.Exit(1)
		}

		fmt.Println("Match.")
	default:
		fmt.Printf("Unrecognised private key type: %T\n", priv.Public())
		os.Exit(1)
	}
}