File: redscat2.c

package info (click to toggle)
mpich 4.3.0%2Breally4.2.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 419,120 kB
  • sloc: ansic: 1,215,557; cpp: 74,755; javascript: 40,763; f90: 20,649; sh: 18,463; xml: 14,418; python: 14,397; perl: 13,772; makefile: 9,279; fortran: 8,063; java: 4,553; asm: 324; ruby: 176; lisp: 19; php: 8; sed: 4
file content (128 lines) | stat: -rw-r--r-- 3,334 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
/*
 * Copyright (C) by Argonne National Laboratory
 *     See COPYRIGHT in top-level directory
 */

/*
 * Test of reduce scatter.
 *
 * Checks that non-commutative operations are not commuted and that
 * all of the operations are performed.
 *
 * Can be run with any number of MPI processes.
 */

#include "mpi.h"
#include <stdio.h>
#include <stdlib.h>
#include "mpitest.h"

/* Error count for this process.  Incremented by the reduction callbacks
 * left() and right() when the incoming operand is larger than the in/out
 * operand, and by main() when a reduce-scatter result differs from its
 * expected value.  main() passes it to MTest_Finalize and returns it. */
int err = 0;

/* left(x,y) ==> x
 *
 * User-defined reduction callback that keeps the left operand: on return,
 * b[k] == a[k] for every k in [0, *count).  `a` is the incoming vector and
 * `b` the in/out vector, following the MPI_User_function calling convention.
 *
 * main() fills each rank's send buffer with rank + index, so lower ranks
 * contribute smaller values.  If the incoming element is larger than the
 * in/out element, the operands presumably arrived in the wrong order, and
 * the global err counter is bumped.
 */
void left(void *a, void *b, int *count, MPI_Datatype * type);
void left(void *a, void *b, int *count, MPI_Datatype * type)
{
    const int *src = a;
    int *dst = b;

    for (int k = 0; k < *count; k++) {
        if (src[k] > dst[k]) {
            err++;
        }
        dst[k] = src[k];
    }
}

/* right(x,y) ==> y
 *
 * User-defined reduction callback that keeps the right operand.  `a` is the
 * incoming vector (x) and `b` the in/out vector (y), following the
 * MPI_User_function calling convention.  The result is written into `b`,
 * which already holds y, so no element needs to be stored.
 *
 * As in left(), an incoming element larger than the in/out element bumps the
 * global err counter.  main() gives lower ranks smaller values, so that
 * situation suggests the operands were commuted.
 */
void right(void *a, void *b, int *count, MPI_Datatype * type);
void right(void *a, void *b, int *count, MPI_Datatype * type)
{
    int *in = a;
    int *inout = b;
    int i;

    for (i = 0; i < *count; ++i) {
        if (in[i] > inout[i])
            ++err;
        /* Result is y == inout[i], which is already in place. */
    }
}

/* nc_sum(x,y) ==> x + y
 *
 * Element-wise integer addition into the in/out vector `b`.  The arithmetic
 * itself does not depend on operand order, which keeps the expected result
 * easy to compute.  main() registers it as a non-commutative op, which may
 * steer the MPI implementation onto a different reduce-scatter algorithm.
 */
void nc_sum(void *a, void *b, int *count, MPI_Datatype * type);
void nc_sum(void *a, void *b, int *count, MPI_Datatype * type)
{
    const int *addend = a;
    int *acc = b;

    for (int k = 0; k < *count; k++)
        acc[k] += addend[k];
}

/* Block sizes 1, 2, 4, ..., MAX_BLOCK_SIZE / 2 are tested; the bound in the
 * loop in main() is exclusive. */
#define MAX_BLOCK_SIZE 256

/* Allocate n ints, terminating this process on failure.  Reporting the
 * failure through err instead would leave the other ranks blocked in the
 * next collective.  The MPI launcher is presumably responsible for tearing
 * down the remaining ranks after abort(). */
static int *redscat_alloc_ints(size_t n)
{
    int *p = malloc(n * sizeof *p);

    if (p == NULL && n != 0) {
        fprintf(stderr, "redscat2: unable to allocate %zu ints\n", n);
        abort();
    }
    return p;
}

/* Fill buf[0..n) with -1.  Every expected result checked in main() is
 * non-negative, so a reduce-scatter that never writes its output is caught
 * rather than passing on values left over from the previous call.  With a
 * single process all three expected results coincide, which would otherwise
 * hide such a failure. */
static void redscat_poison(int *buf, int n)
{
    int i;

    for (i = 0; i < n; i++)
        buf[i] = -1;
}

/* Runs MPI_Reduce_scatter with three non-commutative user ops (left, right,
 * nc_sum) over a range of block sizes and counts mismatches in err. */
int main(int argc, char **argv)
{
    int *sendbuf, *recvcounts, *recvbuf;
    int block_size;
    int size, rank, i;
    MPI_Comm comm;
    MPI_Op left_op, right_op, nc_sum_op;

    MTest_Init(&argc, &argv);
    comm = MPI_COMM_WORLD;

    MPI_Comm_size(comm, &size);
    MPI_Comm_rank(comm, &rank);

    MPI_Op_create(&left, 0 /*non-commutative */ , &left_op);
    MPI_Op_create(&right, 0 /*non-commutative */ , &right_op);
    MPI_Op_create(&nc_sum, 0 /*non-commutative */ , &nc_sum_op);

    for (block_size = 1; block_size < MAX_BLOCK_SIZE; block_size *= 2) {
        sendbuf = redscat_alloc_ints((size_t) block_size * (size_t) size);
        recvbuf = redscat_alloc_ints((size_t) block_size);
        recvcounts = redscat_alloc_ints((size_t) size);

        /* Each rank contributes rank + index, so lower ranks hold smaller
         * values at every position. */
        for (i = 0; i < (size * block_size); i++)
            sendbuf[i] = rank + i;
        for (i = 0; i < size; i++)
            recvcounts[i] = block_size;

        /* left keeps rank 0's contribution: 0 + (rank * block_size + i). */
        redscat_poison(recvbuf, block_size);
        MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, MPI_INT, left_op, comm);
        for (i = 0; i < block_size; ++i)
            if (recvbuf[i] != (rank * block_size + i))
                ++err;

        /* right keeps rank (size - 1)'s contribution. */
        redscat_poison(recvbuf, block_size);
        MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, MPI_INT, right_op, comm);
        for (i = 0; i < block_size; ++i)
            if (recvbuf[i] != ((size - 1) + (rank * block_size) + i))
                ++err;

        /* nc_sum: sum over r of (r + x) = size * x + size * (size - 1) / 2. */
        redscat_poison(recvbuf, block_size);
        MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, MPI_INT, nc_sum_op, comm);
        for (i = 0; i < block_size; ++i) {
            int x = rank * block_size + i;
            if (recvbuf[i] != (size * x + (size - 1) * size / 2))
                ++err;
        }

        free(recvbuf);
        free(sendbuf);
        free(recvcounts);
    }

    MPI_Op_free(&left_op);
    MPI_Op_free(&right_op);
    MPI_Op_free(&nc_sum_op);

    MTest_Finalize(err);

    return err;
}