1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
|
"""
Implements a thread pool for parallel copying of files.
"""
from __future__ import unicode_literals
import typing
import threading
from queue import Queue
from .copy import copy_file_internal, copy_modified_time
from .errors import BulkCopyFailed
from .tools import copy_file_data
if typing.TYPE_CHECKING:
from typing import IO, List, Optional, Text, Tuple, Type
from types import TracebackType
from .base import FS
class _Worker(threading.Thread):
    """Daemon thread that executes tasks taken from a copier's queue."""

    def __init__(self, copier):
        # type: (Copier) -> None
        """Attach the worker to ``copier`` and flag the thread as a daemon.

        Arguments:
            copier: Object whose ``queue`` attribute supplies the tasks and
                whose ``add_error`` method receives task failures.

        """
        super(_Worker, self).__init__()
        self.daemon = True
        self.copier = copier

    def run(self):
        # type: () -> None
        """Run queued tasks until a ``None`` sentinel is received.

        Any `Exception` raised by a task is handed to the copier's
        ``add_error`` instead of terminating the thread.

        """
        pending = self.copier.queue
        finished = False
        while not finished:
            job = pending.get(block=True)
            try:
                if job is None:
                    # Sentinel: leave the loop once it has been accounted for.
                    finished = True
                else:
                    job()
            except Exception as exc:
                self.copier.add_error(exc)
            finally:
                # Every item taken from the queue, sentinel included, is
                # marked as done.
                pending.task_done()
class _Task(object):
    """Common base for the callables placed on the copier queue."""

    def __call__(self):
        # type: () -> None
        """Perform the task; this base implementation does nothing."""
        return None
class _CopyTask(_Task):
"""A callable that copies from one file another."""
def __init__(self, src_file, dst_file):
# type: (IO, IO) -> None
self.src_file = src_file
self.dst_file = dst_file
def __call__(self):
# type: () -> None
try:
copy_file_data(self.src_file, self.dst_file, chunk_size=1024 * 1024)
finally:
try:
self.src_file.close()
finally:
self.dst_file.close()
class Copier(object):
    """Copy files in worker threads.

    Usable as a context manager: entering starts the workers, exiting
    stops them and raises `BulkCopyFailed` if any task recorded an error
    and no other exception is already propagating.

    """

    def __init__(self, num_workers=4, preserve_time=False):
        # type: (int, bool) -> None
        """Create a new copier.

        Arguments:
            num_workers (int): Number of worker threads to start. With
                ``0`` no queue is created and `copy` performs each copy
                directly in the calling thread.
            preserve_time (bool): Whether modification times should be
                copied along with the file data.

        Raises:
            ValueError: If ``num_workers`` is negative.

        """
        if num_workers < 0:
            raise ValueError("num_workers must be >= 0")
        self.num_workers = num_workers
        self.preserve_time = preserve_time
        # Arguments of every copy that was handed to the worker queue.
        self.all_tasks = []  # type: List[Tuple[FS, Text, FS, Text]]
        self.queue = None  # type: Optional[Queue[_Task]]
        self.workers = []  # type: List[_Worker]
        self.errors = []  # type: List[Exception]
        self.running = False

    def start(self):
        # type: () -> None
        """Start the workers."""
        if self.num_workers:
            # The queue is bounded, so `copy` blocks once every worker
            # already has a task waiting.
            self.queue = Queue(maxsize=self.num_workers)
            self.workers = [_Worker(self) for _ in range(self.num_workers)]
            for worker in self.workers:
                worker.start()
        self.running = True

    def stop(self):
        # type: () -> None
        """Stop the workers (will block until they are finished).

        If ``preserve_time`` is set, modification times are copied for
        every queued task after the workers have finished. The workers
        are released and the copier is marked as stopped even if that
        step raises.

        """
        if self.running and self.num_workers:
            # Notify the workers that all tasks have arrived
            # and wait for them to finish.
            for _worker in self.workers:
                self.queue.put(None)
            for worker in self.workers:
                worker.join()

            try:
                # If the "last modified" time is to be preserved, do it now.
                if self.preserve_time:
                    for args in self.all_tasks:
                        copy_modified_time(*args)
            finally:
                # Run the cleanup even when a timestamp copy fails, so a
                # later `stop` call does not act on finished threads.
                # Free up references held by workers
                del self.workers[:]
                self.queue.join()
                self.running = False
        self.running = False

    def add_error(self, error):
        # type: (Exception) -> None
        """Add an exception raised by a task."""
        self.errors.append(error)

    def __enter__(self):
        # type: () -> Copier
        """Start the workers and return the copier itself."""
        self.start()
        return self

    def __exit__(
        self,
        exc_type,  # type: Optional[Type[BaseException]]
        exc_value,  # type: Optional[BaseException]
        traceback,  # type: Optional[TracebackType]
    ):
        # type: (...) -> None
        """Stop the workers and report any task failures.

        Raises:
            BulkCopyFailed: If tasks recorded errors and the ``with``
                block itself exited without an exception.

        """
        self.stop()
        if traceback is None and self.errors:
            raise BulkCopyFailed(self.errors)

    def copy(self, src_fs, src_path, dst_fs, dst_path, preserve_time=False):
        # type: (FS, Text, FS, Text, bool) -> None
        """Copy a file from one fs to another.

        Without a queue the copy runs immediately in the calling thread.
        Otherwise both files are opened here and the data transfer is
        queued for a worker thread.

        Arguments:
            src_fs (FS): Source filesystem.
            src_path (str): Path of the file on the source filesystem.
            dst_fs (FS): Destination filesystem.
            dst_path (str): Path of the file on the destination filesystem.
            preserve_time (bool): Not read by this method; the
                ``preserve_time`` value given to the constructor is the
                one that takes effect.

        """
        if self.queue is None:
            # This should be the most performant for a single-thread
            copy_file_internal(
                src_fs, src_path, dst_fs, dst_path, preserve_time=self.preserve_time
            )
            return

        src_file = src_fs.openbin(src_path, "r")
        try:
            dst_file = dst_fs.openbin(dst_path, "w")
        except Exception:
            src_file.close()
            raise
        self.queue.put(_CopyTask(src_file, dst_file))
        # Record the copy only once it has been queued, so that `stop`
        # never copies timestamps for a file that could not be opened.
        self.all_tasks.append((src_fs, src_path, dst_fs, dst_path))
|