File: download_stream.go

// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package gridfs

import (
	"context"
	"errors"
	"io"
	"math"
	"time"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
)

// ErrWrongIndex is used when the chunk retrieved from the server does not have the expected index.
var ErrWrongIndex = errors.New("chunk index does not match expected index")

// ErrWrongSize is used when the chunk retrieved from the server does not have the expected size.
var ErrWrongSize = errors.New("chunk size does not match expected size")

var errNoMoreChunks = errors.New("no more chunks remaining")

// DownloadStream is an io.Reader that can be used to download a file from a GridFS bucket.
type DownloadStream struct {
	numChunks     int32
	chunkSize     int32
	cursor        *mongo.Cursor
	done          bool
	closed        bool
	buffer        []byte // stores up to 1 chunk if the user-provided buffer isn't big enough
	bufferStart   int
	bufferEnd     int
	expectedChunk int32 // index of next expected chunk
	readDeadline  time.Time
	fileLen       int64

	// The pointer returned by GetFile. This should not be used in the actual DownloadStream code outside of the
	// newDownloadStream constructor because the values can be mutated by the user after calling GetFile. Instead,
	// any values needed in the code should be stored separately and copied over in the constructor.
	file *File
}
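
// A DownloadStream is obtained from the Bucket download helpers, such as
// Bucket.OpenDownloadStream, rather than constructed directly. A minimal usage
// sketch (error handling elided; db and fileID are assumed to exist in the
// caller's code):
//
//	bucket, _ := gridfs.NewBucket(db)
//	ds, _ := bucket.OpenDownloadStream(fileID)
//	defer ds.Close()
//
//	var buf bytes.Buffer
//	_, _ = io.Copy(&buf, ds) // DownloadStream implements io.Reader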

// File represents a file stored in GridFS. When downloading, this type can be
// used to access file information via the DownloadStream.GetFile method.
type File struct {
	// ID is the file's ID. This will match the file ID specified when uploading the file. If an upload helper that
	// does not require a file ID was used, this field will be a primitive.ObjectID.
	ID interface{}

	// Length is the length of this file in bytes.
	Length int64

	// ChunkSize is the maximum number of bytes for each chunk in this file.
	ChunkSize int32

	// UploadDate is the time this file was added to GridFS in UTC. This field is set by the driver and is not configurable.
	// The Metadata field can be used to store a custom date.
	UploadDate time.Time

	// Name is the name of this file.
	Name string

	// Metadata is additional data that was specified when creating this file. This field can be unmarshalled into a
	// custom type using the bson.Unmarshal family of functions.
	Metadata bson.Raw
}
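
// As an example, the Metadata field can be decoded into a caller-defined type
// with bson.Unmarshal (a sketch; the fileMetadata type is hypothetical):
//
//	type fileMetadata struct {
//		Owner string `bson:"owner"`
//	}
//
//	var meta fileMetadata
//	_ = bson.Unmarshal(file.Metadata, &meta)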

var _ bson.Unmarshaler = (*File)(nil)

// unmarshalFile is a temporary type used to unmarshal documents from the files collection and can be transformed into
// a File instance. This type exists to avoid adding BSON struct tags to the exported File type.
type unmarshalFile struct {
	ID         interface{} `bson:"_id"`
	Length     int64       `bson:"length"`
	ChunkSize  int32       `bson:"chunkSize"`
	UploadDate time.Time   `bson:"uploadDate"`
	Name       string      `bson:"filename"`
	Metadata   bson.Raw    `bson:"metadata"`
}

// UnmarshalBSON implements the bson.Unmarshaler interface.
func (f *File) UnmarshalBSON(data []byte) error {
	var temp unmarshalFile
	if err := bson.Unmarshal(data, &temp); err != nil {
		return err
	}

	f.ID = temp.ID
	f.Length = temp.Length
	f.ChunkSize = temp.ChunkSize
	f.UploadDate = temp.UploadDate
	f.Name = temp.Name
	f.Metadata = temp.Metadata
	return nil
}

func newDownloadStream(cursor *mongo.Cursor, chunkSize int32, file *File) *DownloadStream {
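	// The number of chunks is the file length divided by the chunk size, rounded
	// up; e.g. a 255-byte file with a 100-byte chunk size has ceil(255/100) = 3
	// chunks: two full 100-byte chunks and a 55-byte final chunk.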
	numChunks := int32(math.Ceil(float64(file.Length) / float64(chunkSize)))

	return &DownloadStream{
		numChunks: numChunks,
		chunkSize: chunkSize,
		cursor:    cursor,
		buffer:    make([]byte, chunkSize),
		done:      cursor == nil,
		fileLen:   file.Length,
		file:      file,
	}
}

// Close closes this download stream.
func (ds *DownloadStream) Close() error {
	if ds.closed {
		return ErrStreamClosed
	}

	ds.closed = true
	if ds.cursor != nil {
		return ds.cursor.Close(context.Background())
	}
	return nil
}

// SetReadDeadline sets the read deadline for this download stream.
func (ds *DownloadStream) SetReadDeadline(t time.Time) error {
	if ds.closed {
		return ErrStreamClosed
	}

	ds.readDeadline = t
	return nil
}
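
// The deadline applies to subsequent Read and Skip calls; a zero time.Time
// value (the default) means no deadline. A minimal sketch:
//
//	_ = ds.SetReadDeadline(time.Now().Add(5 * time.Second))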

// Read reads the file from the server and copies the data into the provided byte slice p.
func (ds *DownloadStream) Read(p []byte) (int, error) {
	if ds.closed {
		return 0, ErrStreamClosed
	}

	if ds.done {
		return 0, io.EOF
	}

	ctx, cancel := deadlineContext(ds.readDeadline)
	if cancel != nil {
		defer cancel()
	}

	bytesCopied := 0
	var err error
	for bytesCopied < len(p) {
		if ds.bufferStart >= ds.bufferEnd {
			// The buffer is empty, so load data from the next chunk.
			err = ds.fillBuffer(ctx)
			if err != nil {
				if err == errNoMoreChunks {
					if bytesCopied == 0 {
						ds.done = true
						return 0, io.EOF
					}
					return bytesCopied, nil
				}
				return bytesCopied, err
			}
		}

		copied := copy(p[bytesCopied:], ds.buffer[ds.bufferStart:ds.bufferEnd])

		bytesCopied += copied
		ds.bufferStart += copied
	}

	return len(p), nil
}
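
// Read follows io.Reader semantics: it may return fewer than len(p) bytes at
// the end of the file and returns io.EOF once the file is exhausted. A sketch
// of a read loop (process is a hypothetical consumer of the bytes read):
//
//	buf := make([]byte, 1024)
//	for {
//		n, err := ds.Read(buf)
//		process(buf[:n])
//		if err == io.EOF {
//			break
//		}
//		if err != nil {
//			return err
//		}
//	}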

// Skip skips a given number of bytes in the file.
func (ds *DownloadStream) Skip(skip int64) (int64, error) {
	if ds.closed {
		return 0, ErrStreamClosed
	}

	if ds.done {
		return 0, nil
	}

	ctx, cancel := deadlineContext(ds.readDeadline)
	if cancel != nil {
		defer cancel()
	}

	var skipped int64
	var err error

	for skipped < skip {
		if ds.bufferStart >= ds.bufferEnd {
			// The buffer is empty, so load data from the next chunk.
			err = ds.fillBuffer(ctx)
			if err != nil {
				if err == errNoMoreChunks {
					return skipped, nil
				}
				return skipped, err
			}
		}

		toSkip := skip - skipped
		// Cap the amount to skip at the number of unconsumed bytes remaining in the buffer.
		bufferRemaining := ds.bufferEnd - ds.bufferStart
		if toSkip > int64(bufferRemaining) {
			toSkip = int64(bufferRemaining)
		}

		skipped += toSkip
		ds.bufferStart += int(toSkip)
	}

	return skip, nil
}
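
// Skip can be combined with Read to start partway into a file, e.g. skipping
// the first 512 bytes before reading (a sketch; error handling elided):
//
//	_, _ = ds.Skip(512)
//	n, _ := ds.Read(buf)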

// GetFile returns a File object representing the file being downloaded.
func (ds *DownloadStream) GetFile() *File {
	return ds.file
}

func (ds *DownloadStream) fillBuffer(ctx context.Context) error {
	if !ds.cursor.Next(ctx) {
		ds.done = true
		// Check for a cursor error; otherwise, there are no more chunks.
		if ds.cursor.Err() != nil {
			_ = ds.cursor.Close(ctx)
			return ds.cursor.Err()
		}
		return errNoMoreChunks
	}

	chunkIndex, err := ds.cursor.Current.LookupErr("n")
	if err != nil {
		return err
	}

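	// The chunk index "n" is normally stored as an int32, but some tooling may
	// write it as an int64, so accept either representation.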
	var chunkIndexInt32 int32
	if chunkIndexInt64, ok := chunkIndex.Int64OK(); ok {
		chunkIndexInt32 = int32(chunkIndexInt64)
	} else {
		chunkIndexInt32 = chunkIndex.Int32()
	}

	if chunkIndexInt32 != ds.expectedChunk {
		return ErrWrongIndex
	}

	ds.expectedChunk++
	data, err := ds.cursor.Current.LookupErr("data")
	if err != nil {
		return err
	}

	_, dataBytes := data.Binary()
	copied := copy(ds.buffer, dataBytes)

	bytesLen := int32(len(dataBytes))
	if ds.expectedChunk == ds.numChunks {
		// the final chunk can contain fewer than ds.chunkSize bytes
		bytesDownloaded := int64(ds.chunkSize) * (int64(ds.expectedChunk) - int64(1))
		bytesRemaining := ds.fileLen - bytesDownloaded

		if int64(bytesLen) != bytesRemaining {
			return ErrWrongSize
		}
	} else if bytesLen != ds.chunkSize {
		// all intermediate chunks must have size ds.chunkSize
		return ErrWrongSize
	}

	ds.bufferStart = 0
	ds.bufferEnd = copied

	return nil
}