File: sasl_transport.go

package info (click to toggle)
golang-github-colinmarc-hdfs 2.3.0-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, forky, sid, trixie
  • size: 3,760 kB
  • sloc: sh: 130; xml: 40; makefile: 31
file content (82 lines) | stat: -rw-r--r-- 2,311 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
package rpc

import (
	"bytes"
	"fmt"
	"io"

	hadoop "github.com/colinmarc/hdfs/v2/internal/protocol/hadoop_common"
	"github.com/jcmturner/gokrb5/v8/crypto"
	"github.com/jcmturner/gokrb5/v8/gssapi"
	"github.com/jcmturner/gokrb5/v8/iana/keyusage"
	krbtypes "github.com/jcmturner/gokrb5/v8/types"
	"google.golang.org/protobuf/proto"
)

// saslTransport implements encrypted or signed RPC.
//
// It embeds basicTransport for the underlying RPC framing and adds a SASL
// layer on top: readResponse expects each response to arrive as a SASL
// message in the WRAP state and unwraps the GSS-API wrap token it carries
// before decoding the inner RPC response.
type saslTransport struct {
	basicTransport

	// sessionKey is the encryption key used to decrypt and encrypt the payload.
	// When privacy is false, readResponse uses it instead to verify the wrap
	// token's checksum.
	sessionKey krbtypes.EncryptionKey
	// privacy indicates full message encryption. When false, readResponse
	// treats the wrap token payload as plaintext and only checks its checksum.
	privacy bool
}

// readResponse reads a SASL-wrapped RPC response from r and unmarshals the
// inner RPC response message into resp.
//
// The outer frame is a standard RPC response carrying an RpcSaslProto, which
// must be in the WRAP state. Its token is a GSS-API wrap token. When
// t.privacy is set, the token payload is decrypted with t.sessionKey.
// Otherwise the payload is plaintext and its checksum is verified with
// t.sessionKey before it is used. In both cases the unwrapped bytes are an
// ordinary RPC packet: a response header followed by the response message.
//
// It returns errUnexpectedSequenceNumber if the inner header's call ID does
// not match requestID. It returns a *NamenodeError if the inner header
// reports a non-SUCCESS status.
func (t *saslTransport) readResponse(r io.Reader, method string, requestID int32, resp proto.Message) error {
	// First, read the SASL payload as a standard RPC response. It is read
	// with saslRpcCallId rather than requestID; the caller's request ID is
	// checked against the inner header further down.
	sasl := hadoop.RpcSaslProto{}
	if err := t.basicTransport.readResponse(r, method, saslRpcCallId, &sasl); err != nil {
		return err
	}
	if sasl.GetState() != hadoop.RpcSaslProto_WRAP {
		return fmt.Errorf("unexpected SASL state: %s", sasl.GetState().String())
	}

	// The SASL token is a GSS-API wrap token. The second argument tells
	// gokrb5 to expect a token produced by the acceptor (server) side.
	var wrapToken gssapi.WrapToken
	if err := wrapToken.Unmarshal(sasl.GetToken(), true); err != nil {
		return err
	}

	// Recover the plaintext RPC packet from the wrap token.
	var payload []byte
	if t.privacy {
		decrypted, err := crypto.DecryptMessage(wrapToken.Payload, t.sessionKey, keyusage.GSSAPI_ACCEPTOR_SEAL)
		if err != nil {
			return err
		}
		payload = decrypted
	} else {
		// The payload is sent in the clear; only trust it once the checksum
		// has been verified. Wrap with %w so callers can inspect the cause.
		ok, err := wrapToken.Verify(t.sessionKey, keyusage.GSSAPI_ACCEPTOR_SEAL)
		if err != nil {
			return fmt.Errorf("unverifiable message from namenode: %w", err)
		}
		if !ok {
			// Defensive: never accept an unverified payload, even if the
			// library reports failure without an error.
			return fmt.Errorf("unverifiable message from namenode: checksum mismatch")
		}
		payload = wrapToken.Payload
	}

	rrh := &hadoop.RpcResponseHeaderProto{}
	if err := readRPCPacket(bytes.NewReader(payload), rrh, resp); err != nil {
		return err
	}

	switch {
	case int32(rrh.GetCallId()) != requestID:
		return errUnexpectedSequenceNumber
	case rrh.GetStatus() != hadoop.RpcResponseHeaderProto_SUCCESS:
		return &NamenodeError{
			method:    method,
			message:   rrh.GetErrorMsg(),
			code:      int(rrh.GetErrorDetail()),
			exception: rrh.GetExceptionClassName(),
		}
	}

	return nil
}