1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
|
package customizations
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"github.com/aws/smithy-go"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
// AddTreeHashMiddleware registers the TreeHash middleware in the Finalize
// step of stack, positioned with middleware.Before, so that Glacier's
// checksum headers are computed automatically. Any error from the stack
// registration is returned unchanged.
func AddTreeHashMiddleware(stack *middleware.Stack) error {
	treeHash := &TreeHash{}
	return stack.Finalize.Add(treeHash, middleware.Before)
}
// TreeHash is a finalize middleware that populates the Glacier
// X-Amz-Sha256-Tree-Hash and X-Amz-Content-Sha256 request headers whenever
// the caller has not already supplied a tree hash.
type TreeHash struct{}

// ID returns the identifier under which this middleware is registered.
func (*TreeHash) ID() string {
	const id = "Glacier:TreeHash"
	return id
}
// HandleFinalize implements the finalize middleware handler method. It sets
// the Glacier checksum headers on the outgoing *smithyhttp.Request and then
// hands the request to next.
//
// If the request is not a *smithyhttp.Request, or computing the checksums
// fails, the error is returned wrapped in a *smithy.SerializationError and
// next is not invoked.
func (*TreeHash) HandleFinalize(
	ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler,
) (
	output middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
	httpReq, isHTTP := input.Request.(*smithyhttp.Request)
	if !isHTTP {
		typeErr := fmt.Errorf("unknown request type %T", input.Request)
		return output, metadata, &smithy.SerializationError{Err: typeErr}
	}

	if checksumErr := addChecksum(httpReq); checksumErr != nil {
		return output, metadata, &smithy.SerializationError{Err: checksumErr}
	}

	return next.HandleFinalize(ctx, input)
}
// addChecksum computes the Glacier tree hash and linear SHA-256 of req's body
// and stores them, hex-encoded, in the X-Amz-Sha256-Tree-Hash and
// X-Amz-Content-Sha256 headers respectively.
//
// It is a no-op when the request has no body stream or when
// X-Amz-Sha256-Tree-Hash is already set. Otherwise the body must be
// seekable: it is read once to compute both hashes and then rewound with
// RewindStream so it can still be sent. Any existing X-Amz-Content-Sha256
// value is overwritten.
//
// It returns an error if the body is not seekable or cannot be rewound.
func addChecksum(req *smithyhttp.Request) error {
	if req.GetStream() == nil || req.Header.Get("X-Amz-Sha256-Tree-Hash") != "" {
		return nil
	}

	if !req.IsStreamSeekable() {
		return fmt.Errorf("glacier content-sha256 and tree hash can only be automatically computed if the request body is seekable")
	}

	h := computeHashes(req.GetStream())
	// Hashing consumed the body; rewind it so the transport can send it.
	if err := req.RewindStream(); err != nil {
		return fmt.Errorf("rewind glacier request body after computing checksums: %w", err)
	}

	req.Header.Set("X-Amz-Sha256-Tree-Hash", hex.EncodeToString(h.TreeHash))
	req.Header.Set("X-Amz-Content-Sha256", hex.EncodeToString(h.LinearHash))
	return nil
}
// Hash contains information about the tree-hash and linear hash of a
// Glacier payload. This structure is generated by computeHashes().
type Hash struct {
	// TreeHash is the root of the SHA-256 tree hash built over 1 MiB chunks
	// of the payload. computeHashes leaves it nil when no bytes were read.
	TreeHash []byte
	// LinearHash is the plain SHA-256 digest of all bytes of the payload.
	LinearHash []byte
}
// computeHashes reads r in 1 MiB chunks and returns both the Glacier tree
// hash (see computeTreeHash) and the linear SHA-256 of every byte read.
//
// r is consumed from its current position; no seek is performed before or
// after, so callers that need to reuse r must reposition it themselves.
// Reading stops at the first short read or error. Errors are not reported;
// if r fails mid-stream, the hashes silently cover only the bytes read up
// to that point.
//
// See http://docs.aws.amazon.com/amazonglacier/latest/dev/checksum-calculations.html for more information.
func computeHashes(r io.Reader) Hash {
	const chunkSize = 1024 * 1024
	var (
		chunk  = make([]byte, chunkSize)
		linear = sha256.New()
		leaves [][]byte
	)
	for done := false; !done; {
		// io.ReadAtLeast returns a nil error only after filling chunk, so a
		// non-nil error marks the final (possibly partial) chunk.
		n, err := io.ReadAtLeast(r, chunk, chunkSize)
		if n > 0 {
			data := chunk[:n]
			sum := sha256.Sum256(data)
			leaves = append(leaves, sum[:])
			linear.Write(data) // hash.Hash.Write never returns an error
		}
		done = n == 0 || err != nil
	}
	return Hash{
		TreeHash:   computeTreeHash(leaves),
		LinearHash: linear.Sum(nil),
	}
}
// computeTreeHash reduces a list of leaf digests to the root of a Glacier
// SHA-256 tree hash. Glacier derives the leaves from the SHA-256 digests of
// consecutive 1 MB chunks of the data.
//
// Each level is built by hashing adjacent pairs as SHA-256(left || right).
// When a level has an odd number of nodes, the last node is carried up to the
// next level unchanged. An empty input yields nil, and a single digest is
// returned as-is (the same slice, not a copy). With two or more inputs, each
// digest is treated as a 32-byte value: longer ones are truncated and shorter
// ones zero-padded. The input slices are not modified.
//
// See http://docs.aws.amazon.com/amazonglacier/latest/dev/checksum-calculations.html for more information.
func computeTreeHash(hashes [][]byte) []byte {
	if len(hashes) == 0 {
		return nil
	}
	if len(hashes) == 1 {
		return hashes[0]
	}

	level := make([][sha256.Size]byte, len(hashes))
	for i, h := range hashes {
		copy(level[i][:], h)
	}

	var pair [2 * sha256.Size]byte
	for len(level) > 1 {
		// Compact the next level into the front of the same backing array;
		// slot next is always at or before the slots being read.
		next := 0
		for i := 0; i+1 < len(level); i += 2 {
			copy(pair[:sha256.Size], level[i][:])
			copy(pair[sha256.Size:], level[i+1][:])
			level[next] = sha256.Sum256(pair[:])
			next++
		}
		if len(level)%2 == 1 {
			level[next] = level[len(level)-1]
			next++
		}
		level = level[:next]
	}
	return level[0][:]
}
|