// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package hybrid

import (
	"bytes"
	"fmt"

	"github.com/tink-crypto/tink-go/v2/internal/internalapi"
	"github.com/tink-crypto/tink-go/v2/internal/internalregistry"
	"github.com/tink-crypto/tink-go/v2/internal/monitoringutil"
	"github.com/tink-crypto/tink-go/v2/internal/prefixmap"
	"github.com/tink-crypto/tink-go/v2/internal/primitiveset"
	"github.com/tink-crypto/tink-go/v2/keyset"
	"github.com/tink-crypto/tink-go/v2/monitoring"
	"github.com/tink-crypto/tink-go/v2/tink"
)

// NewHybridDecrypt returns a HybridDecrypt primitive from the given keyset handle.
func NewHybridDecrypt(handle *keyset.Handle) (tink.HybridDecrypt, error) {
	ps, err := keyset.Primitives[tink.HybridDecrypt](handle, internalapi.Token{})
	if err != nil {
		return nil, fmt.Errorf("hybrid_factory: cannot obtain primitive set: %s", err)
	}
	return newWrappedHybridDecrypt(ps)
}

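// fullHybridDecryptAdapter adapts a raw HybridDecrypt primitive so that it
// accepts ciphertexts that carry the key's output prefix: it checks and strips
// the prefix before delegating to the raw primitive.
type fullHybridDecryptAdapter struct {
	rawHybridDecrypt tink.HybridDecrypt
	prefix           []byte
}

// compile time assertion that fullHybridDecryptAdapter implements the HybridDecrypt interface.
var _ tink.HybridDecrypt = (*fullHybridDecryptAdapter)(nil)

// Decrypt verifies that ciphertext starts with the expected prefix, strips it,
// and decrypts the remainder with the raw primitive.
func (d *fullHybridDecryptAdapter) Decrypt(ciphertext, contextInfo []byte) ([]byte, error) {
	// This is called by `wrappedHybridDecrypt.Decrypt`, which selects the
	// correct decrypter based on the prefix; if the prefix is not correct,
	// this is a bug.
	if len(ciphertext) < len(d.prefix) {
		return nil, fmt.Errorf("ciphertext too short")
	}
	if !bytes.Equal(d.prefix, ciphertext[:len(d.prefix)]) {
		return nil, fmt.Errorf("ciphertext does not start with the expected prefix %x", d.prefix)
	}
	return d.rawHybridDecrypt.Decrypt(ciphertext[len(d.prefix):], contextInfo)
}
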
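// decrypterAndID pairs a decrypter with the ID of its key, so that a
// successful decryption can be logged against that key.
type decrypterAndID struct {
	decrypter tink.HybridDecrypt
	keyID     uint32
}

// compile time assertion that decrypterAndID implements the HybridDecrypt interface.
var _ tink.HybridDecrypt = (*decrypterAndID)(nil)

// Decrypt delegates to the wrapped decrypter.
func (d *decrypterAndID) Decrypt(ciphertext, contextInfo []byte) ([]byte, error) {
	return d.decrypter.Decrypt(ciphertext, contextInfo)
}
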
// wrappedHybridDecrypt is a HybridDecrypt implementation that uses the underlying primitive set
// for decryption.
type wrappedHybridDecrypt struct {
	decrypters *prefixmap.PrefixMap[decrypterAndID]
	logger     monitoring.Logger
}

// compile time assertion that wrappedHybridDecrypt implements the HybridDecrypt interface.
var _ tink.HybridDecrypt = (*wrappedHybridDecrypt)(nil)

func newWrappedHybridDecrypt(ps *primitiveset.PrimitiveSet[tink.HybridDecrypt]) (*wrappedHybridDecrypt, error) {
	decrypters := prefixmap.New[decrypterAndID]()

	// Make sure the primitives do not implement tink.AEAD.
	if isAEAD(ps.Primary.Primitive) || isAEAD(ps.Primary.FullPrimitive) {
		return nil, fmt.Errorf("hybrid_factory: primary primitive must NOT implement tink.AEAD")
	}
	for _, primitives := range ps.Entries {
		for _, p := range primitives {
			if isAEAD(p.Primitive) || isAEAD(p.FullPrimitive) {
				return nil, fmt.Errorf("hybrid_factory: primitive must NOT implement tink.AEAD")
			}
			// Entries without a full primitive get an adapter that checks and
			// strips the output prefix before calling the raw primitive.
			fullPrimitive := p.FullPrimitive
			if fullPrimitive == nil {
				fullPrimitive = &fullHybridDecryptAdapter{
					rawHybridDecrypt: p.Primitive,
					prefix:           p.OutputPrefix(),
				}
			}
			decrypters.Insert(string(p.OutputPrefix()), decrypterAndID{
				decrypter: fullPrimitive,
				keyID:     p.KeyID,
			})
		}
	}

	logger, err := createDecryptLogger(ps)
	if err != nil {
		return nil, err
	}
	return &wrappedHybridDecrypt{
		decrypters: decrypters,
		logger:     logger,
	}, nil
}

func createDecryptLogger(ps *primitiveset.PrimitiveSet[tink.HybridDecrypt]) (monitoring.Logger, error) {
	if len(ps.Annotations) == 0 {
		return &monitoringutil.DoNothingLogger{}, nil
	}
	keysetInfo, err := monitoringutil.KeysetInfoFromPrimitiveSet(ps)
	if err != nil {
		return nil, err
	}
	return internalregistry.GetMonitoringClient().NewLogger(&monitoring.Context{
		KeysetInfo:  keysetInfo,
		Primitive:   "hybrid_decrypt",
		APIFunction: "decrypt",
	})
}

// Decrypt decrypts the given ciphertext, verifying the integrity of contextInfo.
// It returns the corresponding plaintext if the ciphertext is authenticated.
func (a *wrappedHybridDecrypt) Decrypt(ciphertext, contextInfo []byte) ([]byte, error) {
	it := a.decrypters.PrimitivesMatchingPrefix(ciphertext)
	for decrypter, ok := it.Next(); ok; decrypter, ok = it.Next() {
		pt, err := decrypter.Decrypt(ciphertext, contextInfo)
		if err != nil {
			continue
		}
		a.logger.Log(decrypter.keyID, len(ciphertext))
		return pt, nil
	}
	// Nothing worked.
	a.logger.LogFailure()
	return nil, fmt.Errorf("hybrid_factory: decryption failed")
}