# debpartial-mirror - partial debian mirror package tool
# (c) 2004 Otavio Salvador <otavio@debian.org>, Nat Budin <natb@brandeis.edu>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

import BaseHTTPServer
import logging
import os.path
import SimpleHTTPServer
import socket
import sys
import threading
import unittest

from debpartial_mirror import Download
from TestBase import TestBase

# Let us make our server quiet! ;-)
class MySimpleHTTPRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
    """SimpleHTTPRequestHandler variant whose per-request log output is muted."""

    def log_message(self, format, *args):
        """Drop the log record instead of writing it anywhere."""
        return None

# Our built-in simple HTTP server
class ServerThread(threading.Thread):
    """Thread whose body runs an HTTP server's serve_forever() loop.

    The server object is created in the constructor, so with the default
    BaseHTTPServer.HTTPServer a failure to bind the port is raised from
    ServerThread(...) itself, before the thread is ever started.

    port         -- TCP port to listen on, on all interfaces ('').
    HandlerClass -- request handler class; note that its class-level
                    protocol_version is overwritten with `protocol`, which
                    affects every other user of that class too.
    ServerClass  -- server class, called as ServerClass(address, HandlerClass).
    protocol     -- protocol version string assigned to the handler.
    """

    def __init__(self, port,
                 HandlerClass=MySimpleHTTPRequestHandler,
                 ServerClass=BaseHTTPServer.HTTPServer,
                 protocol="HTTP/1.1"):
        HandlerClass.protocol_version = protocol
        # Create (and, for HTTPServer, bind) the server before initialising
        # the Thread part, exactly as callers expect.
        self.httpd = ServerClass(('', port), HandlerClass)
        threading.Thread.__init__(self)

    def run(self):
        """Thread body: serve requests until serve_forever() returns."""
        self.httpd.serve_forever()

class DownloadTests(TestBase):
    """Functional tests for Download.Download against local HTTP servers.

    setUp writes one zero-filled input package per entry of size_list (each
    that many KiB) to aux_file(<package name>). Tests start ServerThread
    instances on demand, download the packages into the current directory
    and compare sizes. The servers are fetched from with URL paths of the
    form /tests/aux/<name>; this presumably maps onto aux_file(<name>) when
    the suite runs from the source root -- TODO confirm against TestBase.
    """
    # Sizes, in KiB, of the generated input packages.
    size_list = (1000, 5000, 10000)
    # Ports of the servers started by the current test. Kept at class level
    # for compatibility only: setUp rebinds it per test instance so that
    # ports no longer accumulate in one shared list across tests.
    running_servers = []
    # First port __start_http tries; it scans upwards from here.
    port = 80

    def setUp(self):
        self.running_servers = []
        self.__http_servers = []
        self.__createFiles()

    def tearDown(self):
        # Stop and close every server this test started, so listening
        # sockets do not leak into later tests; always remove the files.
        try:
            for server in self.__http_servers:
                server.httpd.shutdown()
                server.httpd.server_close()
        finally:
            self.__removeFiles()

    def __start_http(self):
        """Start a quiet HTTP server on the first port >= self.port that can
        be bound, record it in running_servers and return that port."""
        while True:
            try:
                httpServer = ServerThread(self.port)
            except socket.error:
                # Port privileged or already in use: try the next one, but
                # do not spin forever once the port range is exhausted.
                if self.port >= 65535:
                    raise
                self.port += 1
            else:
                break
        # Daemon thread, so a server that is not shut down can never keep
        # the test process alive.
        httpServer.setDaemon(True)
        httpServer.start()
        self.__http_servers.append(httpServer)
        bound_port = self.port
        self.running_servers.append(bound_port)
        # The next server can start probing right after this one.
        self.port += 1
        return bound_port

    def __get_http_ports(self):
        return self.running_servers

    def __packageName(self, size):
        """Return the package file name used for an input of `size` KiB."""
        return "pkg-example_1.%d_all.deb" % size

    def __url(self, port, name):
        """Return the URL of `name` on the local test server at `port`."""
        return 'http://localhost:%d/tests/aux/%s' % (port, name)

    def __createFile(self, name, size):
        """Write aux_file(name) as `size` KiB of zero bytes.

        Done in Python rather than via a shell `dd` command, so paths
        containing spaces or shell metacharacters work and no external
        tools are needed.
        """
        fname = self.aux_file(name)
        block = b'\0' * 1024
        try:
            with open(fname, 'wb') as out:
                for _ in range(size):
                    out.write(block)
        except (IOError, OSError) as e:
            self.fail('Failed to create input test file: %s' % e)

    def __createFiles(self):
        for size in self.size_list:
            self.__createFile(self.__packageName(size), size)

    def __removeFile(self, name):
        """Remove both the aux input copy and the downloaded copy of `name`."""
        if os.path.exists(self.aux_file(name)):
            os.remove(self.aux_file(name))
        if os.path.exists(name):
            os.remove(name)

    def __removeFiles(self):
        for size in self.size_list:
            self.__removeFile(self.__packageName(size))

    def test_1download_without_server(self):
        """Download: Try to grab a file while HTTPserver is down."""
        d = Download.Download(info="quiet", name="Test1")
        # The input file was already created by setUp; only the host is bogus.
        filename = self.__packageName(self.size_list[0])
        d.get('http://invalid/%s' % filename, filename)
        d.wait_mine()
        self.failIf(os.path.exists(filename),
                    "Created (empty) downloaded file even if HTTP server was down.")
        del d

    def test_2download_one(self):
        """Download: grab a file."""
        port = self.__start_http()

        d = Download.Download(info="quiet", name="Test2")
        filename = self.__packageName(self.size_list[0])
        d.get(self.__url(port, filename), filename)
        d.wait_mine()
        self.failUnlessEqual(os.path.getsize(filename),
                             os.path.getsize(self.aux_file(filename)),
                             "Downloaded file of wrong size.")
        del d

    def test_3download_a_non_existing_file(self):
        """Download: grab a non existing file."""
        port = self.__start_http()

        d = Download.Download(info="quiet", name="Test3")
        filename = "pkg-example-invalid.%d_all.deb" % self.size_list[0]
        d.get(self.__url(port, filename), filename)
        d.wait_mine()
        try:
            self.failIf(os.path.exists(filename),
                        "Downloaded file even though it does not exist on the remote server.")
        finally:
            # tearDown only knows the valid package names; clean this one up
            # here in case the download wrongly produced it.
            self.__removeFile(filename)
        del d

    def test_4dowload_more_file(self):
        """Download: multiple file download"""
        # One server per input file, each download from its own server.
        ports = [self.__start_http() for _ in self.size_list]
        names = [self.__packageName(size) for size in self.size_list]

        d = Download.Download(info="quiet", name="Test 4")
        for port, localname in zip(ports, names):
            d.get(self.__url(port, localname), localname)
        d.wait_mine()

        for localname in names:
            self.failUnlessEqual(os.path.getsize(self.aux_file(localname)),
                                 os.path.getsize(localname),
                                 "Downloaded file has wrong size.")
        del d
        
def suite():
    """Build and return a TestSuite wrapping every test* method of DownloadTests."""
    loader = unittest.TestLoader()
    download_tests = loader.loadTestsFromTestCase(DownloadTests)
    wrapper = unittest.TestSuite()
    wrapper.addTest(download_tests)
    return wrapper

if __name__ == '__main__':
    # Route all root-logger output (DEBUG and above) into a file named after
    # this script: its last three characters (normally ".py") become ".log".
    root_logger = logging.getLogger()
    file_handler = logging.FileHandler(sys.argv[0][:-3] + '.log')
    root_logger.setLevel(logging.DEBUG)
    root_logger.addHandler(file_handler)
    unittest.main(defaultTest='suite')
