1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201
|
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
from unittest import skipUnless
from immutabledict import immutabledict
from parameterized import parameterized_class
from synapse.api.errors import SynapseError
from synapse.types import (
AbstractMultiWriterStreamToken,
MultiWriterStreamToken,
RoomAlias,
RoomStreamToken,
UserID,
get_domain_from_id,
get_localpart_from_id,
map_username_to_mxid_localpart,
)
from tests import unittest
from tests.utils import USE_POSTGRES_FOR_TESTS
class IsMineIDTests(unittest.HomeserverTestCase):
def test_is_mine_id(self) -> None:
self.assertTrue(self.hs.is_mine_id("@user:test"))
self.assertTrue(self.hs.is_mine_id("#room:test"))
self.assertTrue(self.hs.is_mine_id("invalid:test"))
self.assertFalse(self.hs.is_mine_id("@user:test\0"))
self.assertFalse(self.hs.is_mine_id("@user"))
def test_two_colons(self) -> None:
"""Test handling of IDs containing more than one colon."""
# The domain starts after the first colon.
# These functions must interpret things consistently.
self.assertFalse(self.hs.is_mine_id("@user:test:test"))
self.assertEqual("user", get_localpart_from_id("@user:test:test"))
self.assertEqual("test:test", get_domain_from_id("@user:test:test"))
class UserIDTestCase(unittest.HomeserverTestCase):
def test_parse(self) -> None:
user = UserID.from_string("@1234abcd:test")
self.assertEqual("1234abcd", user.localpart)
self.assertEqual("test", user.domain)
self.assertEqual(True, self.hs.is_mine(user))
def test_parse_rejects_empty_id(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("")
def test_parse_rejects_missing_sigil(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("alice:example.com")
def test_parse_rejects_missing_separator(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("@alice.example.com")
def test_validation_rejects_missing_domain(self) -> None:
self.assertFalse(UserID.is_valid("@alice:"))
def test_build(self) -> None:
user = UserID("5678efgh", "my.domain")
self.assertEqual(user.to_string(), "@5678efgh:my.domain")
def test_compare(self) -> None:
userA = UserID.from_string("@userA:my.domain")
userAagain = UserID.from_string("@userA:my.domain")
userB = UserID.from_string("@userB:my.domain")
self.assertTrue(userA == userAagain)
self.assertTrue(userA != userB)
class RoomAliasTestCase(unittest.HomeserverTestCase):
def test_parse(self) -> None:
room = RoomAlias.from_string("#channel:test")
self.assertEqual("channel", room.localpart)
self.assertEqual("test", room.domain)
self.assertEqual(True, self.hs.is_mine(room))
def test_build(self) -> None:
room = RoomAlias("channel", "my.domain")
self.assertEqual(room.to_string(), "#channel:my.domain")
def test_validate(self) -> None:
id_string = "#test:domain,test"
self.assertFalse(RoomAlias.is_valid(id_string))
class MapUsernameTestCase(unittest.TestCase):
def test_pass_througuh(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
def test_upper_case(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
self.assertEqual(
map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
"t_e_s_t__1234",
)
def test_symbols(self) -> None:
self.assertEqual(
map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234"
)
def test_leading_underscore(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")
def test_non_ascii(self) -> None:
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
@parameterized_class(
("token_type",),
[
(MultiWriterStreamToken,),
(RoomStreamToken,),
],
class_name_func=lambda cls,
num,
params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}",
)
class MultiWriterTokenTestCase(unittest.HomeserverTestCase):
"""Tests for the different types of multi writer tokens."""
token_type: type[AbstractMultiWriterStreamToken]
def test_basic_token(self) -> None:
"""Test that a simple stream token can be serialized and unserialized"""
store = self.hs.get_datastores().main
token = self.token_type(stream=5)
string_token = self.get_success(token.to_string(store))
if isinstance(token, RoomStreamToken):
self.assertEqual(string_token, "s5")
else:
self.assertEqual(string_token, "5")
parsed_token = self.get_success(self.token_type.parse(store, string_token))
self.assertEqual(parsed_token, token)
@skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres")
def test_instance_map(self) -> None:
"""Test for stream token with instance map"""
store = self.hs.get_datastores().main
token = self.token_type(stream=5, instance_map=immutabledict({"foo": 6}))
string_token = self.get_success(token.to_string(store))
self.assertEqual(string_token, "m5~1.6")
parsed_token = self.get_success(self.token_type.parse(store, string_token))
self.assertEqual(parsed_token, token)
def test_instance_map_assertion(self) -> None:
"""Test that we assert values in the instance map are greater than the
min stream position"""
with self.assertRaises(ValueError):
self.token_type(stream=5, instance_map=immutabledict({"foo": 4}))
with self.assertRaises(ValueError):
self.token_type(stream=5, instance_map=immutabledict({"foo": 5}))
def test_parse_bad_token(self) -> None:
"""Test that we can parse tokens produced by a bug in Synapse of the
form `m5~`"""
store = self.hs.get_datastores().main
parsed_token = self.get_success(self.token_type.parse(store, "m5~"))
self.assertEqual(parsed_token, self.token_type(stream=5))
|