File: test_e2e_tensorpipe.cpp

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (63 lines) | stat: -rw-r--r-- 1,935 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <gtest/gtest.h>

#include "e2e_test_base.h"

#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
#include <torch/csrc/distributed/rpc/request_callback_no_python.h>
#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
#include <torch/torch.h>

namespace torch {
namespace distributed {
namespace rpc {

#ifdef USE_TENSORPIPE

class TestE2ETensorPipe : public TestE2EBase {
 protected:
  void buildRpcAgent() override {
    auto options = c10d::ProcessGroupGloo::Options::create();
    options->devices.push_back(
        ::c10d::ProcessGroupGloo::createDeviceForHostname(serverAddress));
    float rpcTimeout = 30;

    TensorPipeRpcBackendOptions opts(
        /*numWorkerThreads=*/std::max(16U, std::thread::hardware_concurrency()),
        /*transports=*/nullopt,
        /*channels=*/nullopt,
        /*rpc_timeout=*/rpcTimeout,
        /*init_method=*/"unused");

    rpcAgent = std::make_shared<TensorPipeAgent>(
        store,
        "worker",
        0,
        numWorkers,
        opts,
        std::unordered_map<std::string, DeviceMap>{},
        std::vector<c10::Device>{},
        std::make_unique<RequestCallbackNoPython>());
  }
};

// End-to-end training loop test in C++ so that we can run LSAN on this test to
// catch memory leaks. Enabling LSAN with Python multiprocessing has been
// challenging and we don't have a good solution yet.
TEST_F(TestE2ETensorPipe, TestTrainingLoop) {
  runTrainingLoop();

  // buildRpcAgent() installs a TensorPipeAgent. Check the downcast instead of
  // assuming it, so a fixture change fails the test rather than invoking UB.
  auto tensorpipeAgent = std::dynamic_pointer_cast<TensorPipeAgent>(rpcAgent);
  ASSERT_TRUE(tensorpipeAgent != nullptr)
      << "rpcAgent is expected to be a TensorPipeAgent";

  // Shut down the RPC agent so that all outstanding RPCs are cleaned up.
  tensorpipeAgent->join();
  tensorpipeAgent->shutdown();

  // Ensure the TensorPipe internal bookkeeping has been cleared.
  ASSERT_EQ(0, tensorpipeAgent->numPendingResponses());
  ASSERT_EQ(0, tensorpipeAgent->timeoutMapSize());
  ASSERT_EQ(0, tensorpipeAgent->messageIdToTimeoutMapSize());
}

#endif

} // namespace rpc
} // namespace distributed
} // namespace torch