import unittest
from unittest.mock import Mock

from nbxmpp import dispatcher


class XMLVulnerability(unittest.TestCase):
    """Feed classic XML attack payloads to StanzaDispatcher.

    Each test hands one payload to ``StanzaDispatcher.process_data`` and
    asserts that the handler subscribed to ``'parsing-error'`` was invoked.
    The payloads cover entity-expansion bombs, external entity references,
    and external DTD retrieval.
    """

    def setUp(self):
        self.stream = Mock()
        # Set explicitly: an unset attribute on a Mock is itself a (truthy)
        # Mock, which would make the dispatcher see a websocket stream.
        self.stream.is_websocket = False
        self.dispatcher = dispatcher.StanzaDispatcher(self.stream)
        self._error_handler = Mock()
        self.dispatcher.subscribe('parsing-error', self._error_handler)
        self.dispatcher.reset_parser()

    def _assert_parsing_error(self, data):
        """Process *data* and assert the 'parsing-error' handler fired."""
        self.dispatcher.process_data(data)
        self._error_handler.assert_called()

    def test_exponential_entity_expansion(self):
        """Nested entity definitions ("billion laughs") must be rejected."""
        # Each entity level references the previous one many times, so the
        # expanded size grows multiplicatively with nesting depth.
        bomb = """<?xml version="1.0" encoding="utf-8"?>
        <!DOCTYPE bomb [
            <!ENTITY a "test">
            <!ENTITY b "&a;&a;&a;&a;&a;&a;&a;&a;&a;&a;&a;&a;&a;&a;&a;&a;">
            <!ENTITY c "&b;&b;&b;&b;&b;&b;&b;&b;&b;&b;&b;&b;&b;&b;&b;&b;">
        ]>
        <bomb>&c;</bomb>"""

        self._assert_parsing_error(bomb)

    def test_quadratic_blowup(self):
        """One large entity referenced many times must be rejected."""
        # Expanded size is len(entity_value) * reference_count. The values
        # are large enough to form a real quadratic-blowup document, but
        # bounded (~20 MB if expanded) so a regression in DTD handling does
        # not exhaust memory during the test run.
        entity_value = 'x' * 20000
        references = '&a;' * 1000
        bomb = """<?xml version="1.0" encoding="utf-8"?>
        <!DOCTYPE bomb [
        <!ENTITY a "{value}">
        ]>
        <bomb>{refs}</bomb>""".format(value=entity_value, refs=references)

        self._assert_parsing_error(bomb)

    def test_external_entity_expansion(self):
        """An external (SYSTEM, remote URL) entity must be rejected."""
        bomb = """<?xml version="1.0" encoding="utf-8"?>
        <!DOCTYPE external [
        <!ENTITY ee SYSTEM "http://www.python.org/some.xml">
        ]>
        <root>&ee;</root>"""

        self._assert_parsing_error(bomb)

    def test_external_local_entity_expansion(self):
        """An external entity pointing at a local file must be rejected."""
        # NOTE(review): the DOCTYPE follows the opening stream tag here. XML
        # only allows a DOCTYPE in the prolog, so the reported error may be
        # a plain well-formedness failure rather than entity protection.
        # Confirm which path this test is meant to cover.
        bomb = """<?xml version="1.0" encoding="utf-8"?>
        <stream:stream xmlns:stream='http://etherx.jabber.org/streams' xmlns='jabber:client'>
        <!DOCTYPE external [
        <!ENTITY ee SYSTEM "file:///PATH/TO/simple.xml">
        ]>
        <root>&ee;</root>"""

        self._assert_parsing_error(bomb)

    def test_dtd_retrival(self):
        """A DOCTYPE referencing an external DTD must be rejected."""
        bomb = """<?xml version="1.0" encoding="utf-8"?>
        <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
          "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
        <html>
            <head/>
            <body>text</body>
        </html>"""

        self._assert_parsing_error(bomb)


if __name__ == '__main__':
    # Discover and run the TestCase classes in this module when executed
    # directly; 'module' is spelled out but matches unittest.main's default.
    unittest.main(module='__main__')
