1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
|
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import total_ordering
from cassandra.cluster import Cluster
from cassandra.connection import DefaultEndPoint, EndPoint, DefaultEndPointFactory
from cassandra.metadata import _NodeInfo
from tests.integration import requiressimulacron
from tests.integration.simulacron import SimulacronCluster, PROTOCOL_VERSION
@total_ordering
class AddressEndPoint(EndPoint):
    """Minimal custom ``EndPoint`` that is identified by its address only.

    Equality, hashing and ordering all look at ``address`` exclusively; the
    port is stored for ``resolve()`` but never takes part in comparisons.
    ``total_ordering`` derives the remaining rich comparisons from
    ``__eq__`` and ``__lt__``.
    """

    def __init__(self, address, port=9042):
        """Store the endpoint coordinates.

        :param address: address of the node.
        :param port: port of the node (defaults to 9042).
        """
        self._address = address
        self._port = port

    @property
    def address(self):
        """The address given at construction time."""
        return self._address

    @property
    def port(self):
        """The port given at construction time."""
        return self._port

    def resolve(self):
        """Return the ``(address, port)`` pair, used for connection purposes."""
        return self._address, self._port

    def __eq__(self, other):
        return isinstance(other, AddressEndPoint) and \
            self.address == other.address

    def __hash__(self):
        # Kept consistent with __eq__, which compares addresses only.
        return hash(self.address)

    def __lt__(self, other):
        # Same type guard as __eq__: for foreign types defer to Python's
        # comparison protocol (reflected operation, then TypeError) instead
        # of failing with an AttributeError on ``other.address``.
        if not isinstance(other, AddressEndPoint):
            return NotImplemented
        return self.address < other.address

    def __str__(self):
        return "%s" % self.address

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self.address)
class AddressEndPointFactory(DefaultEndPointFactory):
    """Endpoint factory that hands out ``AddressEndPoint`` instances."""

    def create(self, row):
        """Build an ``AddressEndPoint`` for *row*.

        The address is whatever ``_NodeInfo.get_broadcast_rpc_address``
        returns for the row; the port is left at the ``AddressEndPoint``
        default.

        :param row: node row handed to the factory by the driver
            (presumably a peers/local metadata row -- confirm against caller).
        """
        return AddressEndPoint(_NodeInfo.get_broadcast_rpc_address(row))
@requiressimulacron
class EndPointTests(SimulacronCluster):
    """
    Basic tests to validate the internal use of the EndPoint class.

    @since 3.18
    @jira_ticket PYTHON-1079
    @expected_result all the hosts are using the proper endpoint class
    """

    def test_default_endpoint(self):
        """Every host and the control connection use ``DefaultEndPoint``.

        Runs against ``self.cluster`` (presumably set up by the
        ``SimulacronCluster`` base class -- not visible here).
        """
        hosts = self.cluster.metadata.all_hosts()
        self.assertEqual(len(hosts), 3)
        for host in hosts:
            self.assertIsNotNone(host.endpoint)
            self.assertIsInstance(host.endpoint, DefaultEndPoint)
            self.assertEqual(host.address, host.endpoint.address)
            self.assertEqual(host.broadcast_rpc_address, host.endpoint.address)

        control_endpoint = self.cluster.control_connection._connection.endpoint
        # None check first so a missing endpoint fails with the clearer message.
        self.assertIsNotNone(control_endpoint)
        self.assertIsInstance(control_endpoint, DefaultEndPoint)
        # The control connection must be attached to one of the known hosts.
        endpoints = [host.endpoint for host in hosts]
        self.assertIn(control_endpoint, endpoints)

    def test_custom_endpoint(self):
        """Every host and the control connection use the custom endpoint class.

        A dedicated ``Cluster`` is created with ``AddressEndPointFactory`` and
        an ``AddressEndPoint`` contact point; it is always shut down, even
        when connecting or an assertion fails.
        """
        cluster = Cluster(
            contact_points=[AddressEndPoint('127.0.0.1')],
            protocol_version=PROTOCOL_VERSION,
            endpoint_factory=AddressEndPointFactory(),
            compression=False,
        )
        try:
            cluster.connect(wait_for_all_pools=True)

            hosts = cluster.metadata.all_hosts()
            self.assertEqual(len(hosts), 3)
            for host in hosts:
                self.assertIsNotNone(host.endpoint)
                self.assertIsInstance(host.endpoint, AddressEndPoint)
                self.assertEqual(str(host.endpoint), host.endpoint.address)
                self.assertEqual(host.address, host.endpoint.address)
                self.assertEqual(host.broadcast_rpc_address, host.endpoint.address)

            control_endpoint = cluster.control_connection._connection.endpoint
            self.assertIsNotNone(control_endpoint)
            self.assertIsInstance(control_endpoint, AddressEndPoint)
            endpoints = [host.endpoint for host in hosts]
            self.assertIn(control_endpoint, endpoints)
        finally:
            # Release the cluster's resources on failure as well as success,
            # so a failing assertion does not leak into later tests.
            cluster.shutdown()
|