1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
|
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package dtls
import (
"context"
"github.com/pion/dtls/v3/pkg/crypto/elliptic"
"github.com/pion/dtls/v3/pkg/protocol"
"github.com/pion/dtls/v3/pkg/protocol/alert"
"github.com/pion/dtls/v3/pkg/protocol/extension"
"github.com/pion/dtls/v3/pkg/protocol/handshake"
"github.com/pion/dtls/v3/pkg/protocol/recordlayer"
)
// flight1Parse inspects the handshake cache for the server's answer to the
// ClientHello sent in flight 1 and decides which flight comes next.
//
// Results:
//   - (0, nil, nil) when neither a HelloVerifyRequest nor a ServerHello could
//     be pulled yet, so the caller should keep reading.
//   - whatever flight3Parse returns when a ServerHello is already cached.
//   - (flight3, nil, nil) after a HelloVerifyRequest with an accepted version;
//     its cookie is copied into state.cookie and state.handshakeRecvSequence
//     is set to the sequence returned by the cache.
//   - a fatal ProtocolVersion alert with errUnsupportedProtocolVersion when
//     the HelloVerifyRequest carries a version other than 1.0 or 1.2.
//   - a fatal InternalError alert when the pulled entry is not a
//     *handshake.MessageHelloVerifyRequest.
func flight1Parse(
	ctx context.Context,
	conn flightConn,
	state *State,
	cache *handshakeCache,
	cfg *handshakeConfig,
) (flightVal, *alert.Alert, error) {
	// The server is allowed to skip HelloVerifyRequest, so a ServerHello is
	// an acceptable answer during flight 1 as well.
	verifyRule := handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true}
	serverHelloRule := handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, true}

	nextSeq, pulled, found := cache.fullPullMap(
		state.handshakeRecvSequence,
		state.cipherSuite,
		verifyRule,
		serverHelloRule,
	)
	if !found {
		// Nothing usable has arrived yet; keep reading.
		return 0, nil, nil
	}

	if _, gotServerHello := pulled[handshake.TypeServerHello]; gotServerHello {
		// Flights 1 and 2 were skipped by the server: handle this as flight 3.
		return flight3Parse(ctx, conn, state, cache, cfg)
	}

	verifyReq, isVerifyReq := pulled[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest)
	if !isVerifyReq {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
	}

	// DTLS 1.2 clients must not assume that the server will use the protocol
	// version specified in the HelloVerifyRequest message
	// (RFC 6347 Section 4.2.1), so both 1.0 and 1.2 are accepted here.
	versionAccepted := verifyReq.Version.Equal(protocol.Version1_0) ||
		verifyReq.Version.Equal(protocol.Version1_2)
	if !versionAccepted {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
	}

	// Keep a private copy of the cookie rather than aliasing the message buffer.
	state.cookie = append([]byte{}, verifyReq.Cookie...)
	state.handshakeRecvSequence = nextSeq

	return flight3, nil, nil
}
// flight1Generate resets the per-handshake client state and builds the single
// packet carrying the initial ClientHello.
//
// It zeroes both epochs, restores the default named curve, clears any stored
// cookie and populates a fresh local random (optionally overriding its random
// bytes via cfg.helloRandomBytesGenerator). The ClientHello extensions are
// then assembled from cfg, a stored session is loaded when a session store is
// configured, and a connection ID is requested when a generator is present.
//
// It returns the packets to send, or an error (with a fatal InternalError
// alert when the session store lookup fails).
//
//nolint:cyclop
func flight1Generate(
	conn flightConn,
	state *State,
	_ *handshakeCache,
	cfg *handshakeConfig,
) ([]*packet, *alert.Alert, error) {
	// Start from a clean slate: epoch 0 in both directions, default curve,
	// and no cookie.
	state.localEpoch.Store(uint16(0))
	state.remoteEpoch.Store(uint16(0))
	state.namedCurve = defaultNamedCurve
	state.cookie = nil

	if err := state.localRandom.Populate(); err != nil {
		return nil, nil, err
	}
	if randomBytes := cfg.helloRandomBytesGenerator; randomBytes != nil {
		state.localRandom.RandomBytes = randomBytes()
	}

	// These two extensions are always offered.
	exts := []extension.Extension{
		&extension.SupportedSignatureAlgorithms{
			SignatureHashAlgorithms: cfg.localSignatureSchemes,
		},
		&extension.RenegotiationInfo{
			RenegotiatedConnection: 0,
		},
	}

	// The elliptic curve extensions are only advertised when at least one
	// configured cipher suite is ECC based; stop scanning at the first hit.
	offersECC := false
	for i := 0; i < len(cfg.localCipherSuites) && !offersECC; i++ {
		offersECC = cfg.localCipherSuites[i].ECC()
	}
	if offersECC {
		exts = append(
			exts,
			&extension.SupportedEllipticCurves{
				EllipticCurves: cfg.ellipticCurves,
			},
			&extension.SupportedPointFormats{
				PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
			},
		)
	}

	if len(cfg.localSRTPProtectionProfiles) != 0 {
		exts = append(exts, &extension.UseSRTP{
			ProtectionProfiles:  cfg.localSRTPProtectionProfiles,
			MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier,
		})
	}

	switch cfg.extendedMasterSecret {
	case RequestExtendedMasterSecret, RequireExtendedMasterSecret:
		exts = append(exts, &extension.UseExtendedMasterSecret{
			Supported: true,
		})
	}

	if len(cfg.serverName) != 0 {
		exts = append(exts, &extension.ServerName{ServerName: cfg.serverName})
	}

	if len(cfg.supportedProtocols) != 0 {
		exts = append(exts, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols})
	}

	if cfg.sessionStore != nil {
		cfg.log.Tracef("[handshake] try to resume session")

		saved, err := cfg.sessionStore.Get(conn.sessionKey())
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
		if saved.ID != nil {
			cfg.log.Tracef("[handshake] get saved session: %x", saved.ID)
			state.SessionID = saved.ID
			state.masterSecret = saved.Secret
		}
	}

	// A configured generator signals connection ID support. The generated CID
	// may be zero length, which merely asks the server to send us a CID.
	if cfg.connectionIDGenerator != nil {
		state.setLocalConnectionID(cfg.connectionIDGenerator())
		// Flight 3 relies on a non-nil local CID to decide whether the second
		// ClientHello carries a CID, so a nil result from the generator is
		// normalised to an empty, non-nil slice.
		if state.getLocalConnectionID() == nil {
			state.setLocalConnectionID([]byte{})
		}
		exts = append(exts, &extension.ConnectionID{CID: state.getLocalConnectionID()})
	}

	hello := &handshake.MessageClientHello{
		Version:            protocol.Version1_2,
		SessionID:          state.SessionID,
		Cookie:             state.cookie,
		Random:             state.localRandom,
		CipherSuiteIDs:     cipherSuiteIDs(cfg.localCipherSuites),
		CompressionMethods: defaultCompressionMethods(),
		Extensions:         exts,
	}

	// By default the ClientHello is sent as built; a configured hook receives
	// a copy and supplies the message that is sent instead.
	content := handshake.Handshake{Message: hello}
	if cfg.clientHelloMessageHook != nil {
		content.Message = cfg.clientHelloMessageHook(*hello)
	}

	helloPacket := &packet{
		record: &recordlayer.RecordLayer{
			Header: recordlayer.Header{
				Version: protocol.Version1_2,
			},
			Content: &content,
		},
	}

	return []*packet{helloPacket}, nil, nil
}
|