# Owner(s): ["oncall: distributed"]

import os
import sys
import tempfile
import time
from datetime import timedelta
from sys import platform

import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc

if not dist.is_available():
    # Nothing below can run without torch.distributed; report it and exit with
    # status 0 so the file counts as skipped rather than failed.
    print("torch.distributed not available, skipping tests", file=sys.stderr)
    raise SystemExit(0)

import torch.testing._internal.common_utils as common
from torch._six import string_classes
from torch.testing._internal.common_distributed import (
    skip_if_win32,
    create_tcp_store
)
from torch.testing._internal.common_utils import (
    TestCase,
    load_tests,
    run_tests,
    retry_on_connect_failures,
    ADDRESS_IN_USE,
    CONNECT_TIMEOUT,
)

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. Re-binding it here marks the import as used, which
# silences flake8's unused-import warning.
load_tests = load_tests

# Name of the loopback network interface: "lo0" on macOS, "lo" elsewhere.
LOOPBACK = "lo0" if platform == "darwin" else "lo"

# Host name used for every TCPStore / rendezvous address in this file.
DEFAULT_HOSTNAME = "localhost"

# Disallow TF32 in CUDA matmuls for the tests in this file.
torch.backends.cuda.matmul.allow_tf32 = False


def gpus_for_rank(world_size):
    """Split the visible CUDA devices evenly across ``world_size`` ranks.

    Multigpu tests are designed to simulate multiple nodes with multiple GPUs
    on each node. The NCCL backend requires an equal number of GPUs in each
    process, so on a single node the visible GPUs are divided into equal,
    contiguous subsets and each process only uses its own subset. Devices left
    over when the count is not divisible by ``world_size`` are not assigned.

    Args:
        world_size: number of ranks to split the visible devices across.

    Returns:
        A list with ``world_size`` entries; entry ``r`` is the (possibly empty)
        list of device indices assigned to rank ``r``.
    """
    # Query the device count once so the device list and the per-rank share
    # are derived from the same value.
    device_count = torch.cuda.device_count()
    visible_devices = list(range(device_count))
    gpus_per_process = device_count // world_size
    return [
        visible_devices[rank * gpus_per_process: (rank + 1) * gpus_per_process]
        for rank in range(world_size)
    ]


class StoreTestBase(object):
    """Mixin holding store checks shared by the concrete ``*StoreTest`` classes.

    Subclasses combine it with ``TestCase`` and override ``_create_store()`` to
    return a fresh store. ``num_keys_total`` can be overridden by stores that
    keep extra internal keys.
    """

    def _create_store(self):
        """Return a new store to test; concrete test classes must override this.

        Callers invoke it as ``self._create_store()``, so the stub takes no
        arguments.

        Raises:
            NotImplementedError: always. It subclasses ``RuntimeError``, so code
                expecting the previous ``RuntimeError`` still catches it.
        """
        raise NotImplementedError("not implemented")

    def _test_set_get(self, fs):
        """Exercise add/set/get/delete_key/num_keys on the store ``fs``."""
        # "key" accumulates 1 + 2 + 3 = 6.
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        # "key3" accumulates 1 + 2 + ... + 6 = 21, interleaved with sets.
        fs.add("key3", 1)
        fs.set("key1", "value1")
        fs.add("key3", 2)
        fs.set("key2", "value2")
        fs.add("key3", 3)
        fs.add("key3", 4)
        fs.add("key3", 5)
        fs.add("key3", 6)
        self.assertEqual(fs.num_keys(), self.num_keys_total)
        self.assertEqual(b"6", fs.get("key"))
        self.assertEqual(b"value0", fs.get("key0"))
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key2"))
        self.assertEqual(b"21", fs.get("key3"))

        # A temporary key that is deleted again, so the key count must return
        # to its previous value.
        fs.set("-key3", "7")
        self.assertEqual(b"7", fs.get("-key3"))
        fs.delete_key("-key3")
        self.assertEqual(fs.num_keys(), self.num_keys_total)

    def test_set_get(self):
        self._test_set_get(self._create_store())

    def _test_compare_set(self, store):
        """Check ``compare_set`` semantics on ``store``.

        As asserted below:
        - missing key with a non-empty expected value returns the expected value;
        - existing key with a mismatched expected value returns the current
          value and leaves it unchanged;
        - a matching expected value swaps in and returns the new value;
        - missing key with an empty expected value creates the key with the
          new value.
        """
        missing_key_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
        self.assertEqual(b"wrong_old_value", missing_key_result)

        store.set("cs_key0", "value0")
        self.assertEqual(b"value0", store.get("cs_key0"))
        old_value_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
        self.assertEqual(b"value0", old_value_result)
        self.assertEqual(b"value0", store.get("cs_key0"))
        new_value_result = store.compare_set("cs_key0", "value0", "new_value0")
        self.assertEqual(b"new_value0", new_value_result)
        self.assertEqual(b"new_value0", store.get("cs_key0"))
        empty_old_value_result = store.compare_set("cs_key1", "", "new_value1")
        self.assertEqual(b"new_value1", empty_old_value_result)
        self.assertEqual(b"new_value1", store.get("cs_key1"))

    def test_compare_set(self):
        self._test_compare_set(self._create_store())

    # This is the number of keys used in test_set_get. Adding this as a class
    # property instead of hardcoding in the test since some Store
    # implementations will have differing number of keys. In the base case,
    # there will be 5 keys: key, key0, key1, key2, key3.
    @property
    def num_keys_total(self):
        return 5


class FileStoreTest(TestCase, StoreTestBase):
    """Runs the shared StoreTestBase checks against a single-worker ``dist.FileStore``."""

    def setUp(self):
        super(FileStoreTest, self).setUp()
        # delete=False because FileStore is expected to clean up the file itself.
        # Only the path is needed, so close our handle immediately instead of
        # keeping a file descriptor open for the lifetime of the test.
        self.file = tempfile.NamedTemporaryFile(delete=False)
        self.file.close()

    def _create_store(self):
        """Return a FileStore (world size 1) on the temp file with a 300s timeout."""
        store = dist.FileStore(self.file.name, 1)
        store.set_timeout(timedelta(seconds=300))
        return store


@skip_if_win32()
class HashStoreTest(TestCase, StoreTestBase):
    """Runs the shared StoreTestBase checks against an in-memory ``dist.HashStore``.

    Skipped on Windows.
    """

    def setUp(self):
        super(HashStoreTest, self).setUp()

    def _create_store(self):
        hash_store = dist.HashStore()
        hash_store.set_timeout(timedelta(seconds=300))
        return hash_store

class PrefixStoreTest(TestCase):
    """Tests of ``dist.PrefixStore`` behavior that apply to any wrapped store type."""

    def setUp(self):
        # Run TestCase's common setup first, like every other test class here.
        super(PrefixStoreTest, self).setUp()
        # delete is false as FileStore will automatically clean up the file.
        # Only the path is needed, so close our handle right away instead of
        # leaking an open descriptor for the lifetime of the test.
        self.file = tempfile.NamedTemporaryFile(delete=False)
        self.file.close()

    def test_get_underlying_store(self):
        """``underlying_store`` returns the wrapped store for TCP, hash and file stores."""
        tcp_store = dist.TCPStore(host_name=DEFAULT_HOSTNAME, port=0, world_size=1, is_master=True)
        hash_store = dist.HashStore()
        file_store = dist.FileStore(self.file.name, world_size=1)
        for store in [tcp_store, hash_store, file_store]:
            with self.subTest(f"Testing getting underlying_store for {type(store)}"):
                prefix_store = dist.PrefixStore("prefix", store)
                self.assertEqual(prefix_store.underlying_store, store)

class PrefixFileStoreTest(TestCase, StoreTestBase):
    """Runs the shared StoreTestBase checks through a PrefixStore wrapping a FileStore."""

    def setUp(self):
        super(PrefixFileStoreTest, self).setUp()
        # delete=False because FileStore is expected to clean up the file itself.
        # Only the path is needed, so close our handle immediately instead of
        # keeping a file descriptor open for the lifetime of the test.
        self.file = tempfile.NamedTemporaryFile(delete=False)
        self.file.close()
        self.filestore = dist.FileStore(self.file.name, 1)
        self.prefix = "test_prefix"
        self.filestore.set_timeout(timedelta(seconds=300))

    def _create_store(self):
        """Wrap the shared FileStore in a PrefixStore using ``self.prefix``."""
        return dist.PrefixStore(self.prefix, self.filestore)


class TCPStoreTest(TestCase, StoreTestBase):
    """Shared StoreTestBase checks plus TCPStore-specific tests.

    The TCPStore-specific tests cover port reuse, multi-tenancy, key counting
    and multiple clients.
    """

    def _create_store(self):
        """Return a TCPStore from ``create_tcp_store`` with a 300s timeout."""
        store = create_tcp_store()
        store.set_timeout(timedelta(seconds=300))
        return store

    @retry_on_connect_failures
    def test_address_already_in_use(self):
        err_msg_reg = "^The server socket has failed to listen on any local "
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        # Only the *second* server on the same port is expected to fail, so the
        # first one is created outside assertRaisesRegex. Otherwise a failure of
        # the first server (e.g. the "free" port was taken in the meantime)
        # would make this test pass without checking anything. Now that case
        # raises, and may be retried by @retry_on_connect_failures.
        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
        with self.assertRaisesRegex(RuntimeError, err_msg_reg):
            store2 = dist.TCPStore(addr, port, 1, True)  # noqa: F841

    @retry_on_connect_failures
    def test_multitenancy(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841
        store2 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841

    @skip_if_win32()
    @retry_on_connect_failures
    def test_init_pg_and_rpc_with_same_socket(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        os.environ["MASTER_ADDR"] = addr
        os.environ["MASTER_PORT"] = str(port)

        # We internally use a multi-tenant TCP store. Both PG and RPC should successfully
        # initialize even when using the same socket address.

        dist.init_process_group(
            backend="gloo",
            init_method="env://",
            rank=0,
            world_size=1,
        )
        try:
            backend_opts = rpc.TensorPipeRpcBackendOptions(
                init_method=f"tcp://{addr}:{port}"
            )
            rpc.init_rpc(
                name="worker0",
                rank=0,
                world_size=1,
                rpc_backend_options=backend_opts,
            )

            rpc.shutdown()
        finally:
            # Tear down the default process group so that later tests in this
            # process, or a retry of this test, can initialize it again.
            dist.destroy_process_group()

    # The TCPStore has 6 keys in test_set_get. It contains the 5 keys added by
    # the user and one additional key used to coordinate all the workers.
    @property
    def num_keys_total(self):
        return 6

    def _test_numkeys_delkeys(self, fs):
        """Check ``num_keys`` bookkeeping across adds, sets and deletes on ``fs``."""
        # We start off with one init key in the store to coordinate workers
        self.assertEqual(fs.num_keys(), 1)
        fs.add("key", 1)
        fs.add("key", 2)
        fs.add("key", 3)
        fs.set("key0", "value0")
        fs.add("key3", 1)
        fs.set("key1", "value1")
        self.assertEqual(fs.num_keys(), 5)
        fs.delete_key("key")
        self.assertEqual(fs.num_keys(), 4)
        # Short timeout: getting the deleted key is expected to fail.
        fs.set_timeout(timedelta(seconds=2))
        with self.assertRaises(RuntimeError):
            fs.get("key")
        fs.delete_key("key0")
        fs.delete_key("key3")
        self.assertEqual(fs.num_keys(), 2)
        fs.set("key4", "value2")
        self.assertEqual(fs.num_keys(), 3)
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key4"))

    def test_numkeys_delkeys(self):
        self._test_numkeys_delkeys(self._create_store())

    def _create_client(self, index, addr, port, world_size):
        """Connect a client store to ``addr:port`` and run set/compare_set as worker ``index``."""
        client_store = dist.TCPStore(addr, port, world_size=world_size, timeout=timedelta(seconds=10))
        self.assertEqual("value".encode(), client_store.get("key"))
        client_store.set(f"new_key{index}", f"new_value{index}")
        self.assertEqual(f"next_value{index}".encode(),
                         client_store.compare_set(f"new_key{index}", f"new_value{index}", f"next_value{index}"))

    def _multi_worker_helper(self, world_size):
        """Start a server store and connect ``world_size`` clients (one if ``world_size`` is falsy)."""
        addr = DEFAULT_HOSTNAME
        server_store = create_tcp_store(addr, world_size, wait_for_workers=False)
        server_store.set("key", "value")
        port = server_store.port

        num_indices = world_size if world_size else 1
        for i in range(num_indices):
            self._create_client(i, addr, port, world_size)

    def test_multi_worker_with_fixed_world_size(self):
        self._multi_worker_helper(5)

    def test_multi_worker_with_nonfixed_world_size(self):
        self._multi_worker_helper(None)

class PrefixTCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared StoreTestBase checks through a PrefixStore wrapping a TCPStore."""

    def setUp(self):
        super(PrefixTCPStoreTest, self).setUp()
        self.prefix = "test_prefix"
        self.tcpstore = create_tcp_store()
        self.tcpstore.set_timeout(timedelta(seconds=300))

    def _create_store(self):
        wrapped_store = dist.PrefixStore(self.prefix, self.tcpstore)
        return wrapped_store

    # test_set_get sees 6 keys here: the 5 keys added by the user plus one
    # additional key used by the underlying TCPStore to coordinate the workers.
    @property
    def num_keys_total(self):
        return 6


class MyPythonStore(dist.Store):
    """Minimal dict-backed ``dist.Store`` written in Python.

    It is handed to ``dist._test_python_store`` to exercise the C++ trampoline
    into Python. Each method raises AssertionError if it sees argument or
    return types other than the ones checked below.
    """

    def __init__(self):
        super(MyPythonStore, self).__init__()
        self.store = {}

    def set(self, key, value):
        if not isinstance(key, string_classes):
            raise AssertionError("Expected set to be called with string key")
        if type(value) is not bytes:
            raise AssertionError("Expected set to be called with bytes value")
        self.store[key] = value

    def get(self, key):
        # Missing keys read back as empty bytes instead of raising.
        result = self.store.get(key, b"")
        if type(result) is not bytes:
            raise AssertionError("Expected get to return bytes value")
        return result

    def add(self, key, value):
        # Stored counters are bytes such as b"6"; a missing key counts as 0.
        current = self.store.get(key, 0)
        total = int(current) + value
        self.set(key, str(total).encode("utf-8"))
        return total


class PythonStoreTest(TestCase):
    """Drives a Python-implemented store through the C++ store test helper."""

    def setUp(self):
        super(PythonStoreTest, self).setUp()

    def test_set_get(self):
        # Inheriting StoreTestBase would call the Python store API directly and
        # bypass the C++ trampoline, which is what this test is meant to cover.
        # Instead, dist._test_python_store runs the equivalent of
        # StoreTestBase.test_set_get from C++; see
        # `torch/csrc/distributed/c10d/init.cpp` for its definition.
        python_store = MyPythonStore()
        dist._test_python_store(python_store)


class RendezvousTest(TestCase):
    """Argument validation in ``dist.rendezvous`` that does not depend on a specific handler."""

    def test_unknown_handler(self):
        unknown_url = "invalid://"
        with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"):
            dist.rendezvous(unknown_url)

    def test_url_with_node_params(self):
        # The URL already carries rank/world_size in its query while rank and
        # world size are also passed as arguments; rendezvous is expected to
        # reject this combination.
        url_with_params = "file://foo?rank=12&world_size=16"
        with self.assertRaisesRegex(AssertionError, "has node-specific arguments"):
            dist.rendezvous(url_with_params, 12, 16)


class RendezvousEnvTest(TestCase):
    """Tests for the ``env://`` rendezvous handler."""

    @retry_on_connect_failures
    def test_nominal(self):
        os.environ.update(
            {
                "WORLD_SIZE": "1",
                "MASTER_ADDR": "127.0.0.1",
                "MASTER_PORT": str(common.find_free_port()),
            }
        )

        # Single rank
        os.environ["RANK"] = "0"
        gen0 = dist.rendezvous("env://")
        store0, rank0, size0 = next(gen0)
        self.assertEqual(0, rank0)
        self.assertEqual(1, size0)

        # Round-trip a value through the rendezvous store.
        store0.set("key0", "value0")
        self.assertEqual(b"value0", store0.get("key0"))


class RendezvousFileTest(TestCase):
    """Tests for the ``file://`` rendezvous handler."""

    def test_common_errors(self):
        malformed_urls = [
            ("path missing", "file://?rank=0&world_size=1"),
            ("rank parameter missing", "file:///tmp/foo?world_size=1"),
            ("size parameter missing", "file:///tmp/foo?rank=0"),
        ]
        for expected_msg, bad_url in malformed_urls:
            with self.assertRaisesRegex(ValueError, expected_msg):
                next(dist.rendezvous(bad_url))

    def test_nominal(self):
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            posix_path = tmp_file.name.replace(os.path.sep, "/")
            base_url = f'file:///{posix_path}?world_size=2'

            gen0 = dist.rendezvous(base_url + "&rank=0")
            store0, rank0, size0 = next(gen0)
            self.assertEqual(0, rank0)
            self.assertEqual(2, size0)

            gen1 = dist.rendezvous(base_url + "&rank=1")
            store1, rank1, size1 = next(gen1)
            self.assertEqual(1, rank1)
            self.assertEqual(2, size1)

            # Each rank writes one key, then reads the other rank's key.
            store0.set("key0", "value0")
            store1.set("key1", "value1")
            self.assertEqual(b"value0", store1.get("key0"))
            self.assertEqual(b"value1", store0.get("key1"))


@skip_if_win32()
class RendezvousTCPTest(TestCase):
    """Tests for the ``tcp://`` rendezvous handler (skipped on Windows)."""

    def create_tcp_url(self):
        """Return a single-worker ``tcp://`` URL on a free local port, without a rank."""
        host = DEFAULT_HOSTNAME
        free_port = common.find_free_port()
        return f"tcp://{host}:{free_port}?world_size={1}"

    def test_common_errors(self):
        malformed_urls = [
            ("port number missing", "tcp://127.0.0.1?rank=0&world_size=1"),
            ("rank parameter missing", "tcp://127.0.0.1:23456?world_size=1"),
            ("size parameter missing", "tcp://127.0.0.1:23456?rank=0"),
        ]
        for expected_msg, bad_url in malformed_urls:
            with self.assertRaisesRegex(ValueError, expected_msg):
                next(dist.rendezvous(bad_url))

    def test_dns_timeout(self):
        unresolvable_url = "tcp://dnsnotexist:23456?world_size=2&rank=0"
        with self.assertRaisesRegex(TimeoutError, "client socket has timed out after.*dnsnotexist"):
            gen = dist.rendezvous(unresolvable_url, timeout=timedelta(seconds=1))
            next(gen)

    @retry_on_connect_failures
    def test_nominal(self):
        base_url = self.create_tcp_url()
        gen0 = dist.rendezvous(base_url + "&rank=0")
        store0, rank0, size0 = next(gen0)
        self.assertEqual(0, rank0)
        self.assertEqual(1, size0)

        # Round-trip a value through the single store.
        store0.set("key0", "value0")
        self.assertEqual(b"value0", store0.get("key0"))

    @retry_on_connect_failures(connect_errors=(CONNECT_TIMEOUT, ADDRESS_IN_USE))
    def test_tcp_store_timeout_set(self):
        base_url = self.create_tcp_url()
        test_store_timeout = timedelta(seconds=10)
        gen0 = dist.rendezvous(base_url + "&rank=0", timeout=test_store_timeout)
        store0, _rank0, _size0 = next(gen0)

        # The get below should time out after about 10s. If rendezvous ignored
        # the timeout it was given, it would take much longer to fail.
        started_at = time.time()
        with self.assertRaisesRegex(RuntimeError, "Timeout"):
            store0.get("nonexistant key")
        elapsed = time.time() - started_at
        self.assertGreater(test_store_timeout.seconds * 10, elapsed)


if __name__ == "__main__":
    # Guard: nothing in this file may initialize CUDA in the parent process.
    assert not torch.cuda._initialized, (
        "test_distributed must not have initialized CUDA context on main process"
    )
    run_tests()
