File: transpose.f

package info (click to toggle)
libxsmm 1.17-4
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 14,976 kB
  • sloc: ansic: 119,587; cpp: 27,680; fortran: 9,179; sh: 5,765; makefile: 5,040; pascal: 2,312; python: 1,812; f90: 1,773
file content (152 lines) | stat: -rw-r--r-- 5,617 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
!=======================================================================!
! Copyright (c) Intel Corporation - All rights reserved.                !
! This file is part of the LIBXSMM library.                             !
!                                                                       !
! For information on the license, see the LICENSE file.                 !
! Further information: https://github.com/hfp/libxsmm/                  !
! SPDX-License-Identifier: BSD-3-Clause                                 !
!=======================================================================!
! Hans Pabst (Intel Corp.)
!=======================================================================!

      ! Benchmark and validate LIBXSMM's matrix transpose routines.
      ! Positional command-line arguments (all optional):
      !   1 trans    'o' or 'O' selects out-of-place (default 'o'),
      !              any other character selects the in-place variant
      !   2 m        rows of the source matrix (default 4096)
      !   3 n        columns of the source matrix (default m)
      !   4 ldi      leading dimension of the source (default m)
      !   5 ldo      leading dimension of the result (default ldi)
      !   6 nrepeat  number of timed repetitions (default 3)
      ! Leading dimensions are raised to what the shapes require.
      ! Prints size, bandwidth and per-call duration; executes STOP 1
      ! if the transposed data does not match the expected values.
      PROGRAM transpose
        USE :: LIBXSMM, ONLY: LIBXSMM_BLASINT_KIND,                     &
     &                        libxsmm_timer_duration,                   &
     &                        libxsmm_timer_tick,                       &
     &                        libxsmm_otrans_omp,                       &
     &                        libxsmm_otrans,                           &
     &                        libxsmm_itrans,                           &
     &                        ptr => libxsmm_ptr
        IMPLICIT NONE

        INTEGER, PARAMETER :: T = KIND(0D0)
        ! Bytes per REAL(T) element (also the typesize argument of the
        ! pointer-based calls that are commented out below)
        INTEGER, PARAMETER :: S = 8

        REAL(T), ALLOCATABLE, TARGET :: a1(:), b1(:)
        !DIR$ ATTRIBUTES ALIGN:64 :: a1, b1
        INTEGER(LIBXSMM_BLASINT_KIND) :: m, n, ldi, ldo, i, j, k
        REAL(T), POINTER :: an(:,:), bn(:,:), bt(:,:)
        DOUBLE PRECISION :: duration
        INTEGER(8) :: nbytes, start
        INTEGER :: nrepeat
        REAL(T) :: diff

        CHARACTER(32) :: argv
        CHARACTER :: trans
        INTEGER :: argc
        LOGICAL :: outofplace

        argc = COMMAND_ARGUMENT_COUNT()
        IF (1 <= argc) THEN
          CALL GET_COMMAND_ARGUMENT(1, trans)
        ELSE
          trans = 'o'
        END IF
        IF (2 <= argc) THEN
          CALL GET_COMMAND_ARGUMENT(2, argv)
          READ(argv, "(I32)") m
        ELSE
          m = 4096
        END IF
        IF (3 <= argc) THEN
          CALL GET_COMMAND_ARGUMENT(3, argv)
          READ(argv, "(I32)") n
        ELSE
          n = m
        END IF
        IF (4 <= argc) THEN
          CALL GET_COMMAND_ARGUMENT(4, argv)
          READ(argv, "(I32)") ldi
        ELSE
          ldi = m
        END IF
        IF (5 <= argc) THEN
          CALL GET_COMMAND_ARGUMENT(5, argv)
          READ(argv, "(I32)") ldo
        ELSE
          ldo = ldi
        END IF
        IF (6 <= argc) THEN
          CALL GET_COMMAND_ARGUMENT(6, argv)
          READ(argv, "(I32)") nrepeat
        ELSE
          nrepeat = 3
        END IF

        outofplace = ('o'.EQ.trans).OR.('O'.EQ.trans)

        ! Sanitize leading dimensions: the source is m-by-n (ldi >= m)
        ! and the result is n-by-m (ldo >= n). Without this, the default
        ! ldo = ldi = m is too small for the result whenever n > m.
        ldi = MAX(ldi, m)
        ldo = MAX(ldo, n)
        IF (.NOT.outofplace) THEN
          ! In-place: the buffer is laid out and validated with ldo,
          ! whereas libxsmm_itrans only receives ldi; keep them equal
          ! (and large enough for both the source and its transpose).
          ldi = MAX(ldi, ldo)
          ldo = ldi
        END IF

        ! widen before multiplying: m * n may overflow the BLASINT kind
        nbytes = INT(m, 8) * INT(n, 8) * S ! size in Byte
        WRITE(*, "(2(A,I0),2(A,I0),A,I0,A)")                            &
     &    "m=", m, " n=", n, " ldi=", ldi, " ldo=", ldo,                &
     &    " size=", (nbytes / ISHFT(1, 20)), "MB"

        ! b1 holds the (in-place) source and the result; bn views it as
        ! an ldo-by-n array, bt as the ldo-by-m transposed result
        ALLOCATE(b1(INT(ldo, 8) * MAX(m, n)))
        bn(1:ldo,1:n) => b1
        bt(1:ldo,1:m) => b1

        IF (outofplace) THEN
          ALLOCATE(a1(INT(ldi, 8) * n))
          an(1:ldi,1:n) => a1
          !$OMP PARALLEL DO PRIVATE(i, j) DEFAULT(NONE) SHARED(m, n, an)
          DO j = 1, n
            DO i = 1, m
              an(i,j) = initial_value(i - 1, j - 1, m)
            END DO
          END DO
          !$OMP END PARALLEL DO
          start = libxsmm_timer_tick()
          DO k = 1, nrepeat
            !CALL libxsmm_otrans_omp(ptr(b1), ptr(a1), S, m, n, ldi, ldo)
            !CALL libxsmm_otrans(ptr(b1), ptr(a1), S, m, n, ldi, ldo)
            !CALL libxsmm_otrans(bn, an, m, n, ldi, ldo)
            CALL libxsmm_otrans(b1, a1, m, n, ldi, ldo)
          END DO
          duration = libxsmm_timer_duration(start, libxsmm_timer_tick())
          DEALLOCATE(a1)
        ELSE ! in-place
          !$OMP PARALLEL DO PRIVATE(i, j) DEFAULT(NONE) SHARED(m, n, bn)
          DO j = 1, n
            DO i = 1, m
              bn(i,j) = initial_value(i - 1, j - 1, m)
            END DO
          END DO
          !$OMP END PARALLEL DO
          start = libxsmm_timer_tick()
          DO k = 1, nrepeat
            !CALL libxsmm_itrans(ptr(b1), S, m, n, ldi, ldo)
            !CALL libxsmm_itrans(bn, m, n, ldi)
            CALL libxsmm_itrans(b1, m, n, ldi)
          END DO
          duration = libxsmm_timer_duration(start, libxsmm_timer_tick())
        END IF

        ! largest absolute deviation of the result from the transpose
        diff = REAL(0, T)
        DO j = 1, n
          DO i = 1, m
            diff = MAX(diff,                                            &
     &                ABS(bt(j,i) - initial_value(i - 1, j - 1, m)))
          END DO
        END DO
        DEALLOCATE(b1)

        IF (0.GE.diff) THEN
          IF ((0.LT.duration).AND.(0.LT.nrepeat)) THEN
            ! out-of-place transpose bandwidth assumes RFO
            WRITE(*, "(1A,A,F10.1,A)") CHAR(9), "bandwidth:  ",         &
     &        REAL(nbytes, T)                                           &
     &        * MERGE(3D0, 2D0, outofplace)                             &
     &        * REAL(nrepeat, T) / (duration * REAL(ISHFT(1_8, 30), T)),&
     &        " GB/s"
            WRITE(*, "(1A,A,F10.1,A)") CHAR(9), "duration:   ",         &
     &        1D3 * duration / REAL(nrepeat, T),                        &
     &        " ms"
          END IF
        ELSE
          WRITE(*,*) "Validation failed!"
          STOP 1
        END IF

      CONTAINS
        ! Expected value of source element (i,j) given zero-based i, j:
        ! the column-major linear index j*m+i, unique per element.
        PURE REAL(T) FUNCTION initial_value(i, j, m)
          INTEGER(LIBXSMM_BLASINT_KIND), INTENT(IN) :: i, j, m
          ! Evaluate in REAL(T): the integer product j*m may overflow
          ! if LIBXSMM_BLASINT_KIND is a 32-bit kind (exact < 2**53).
          initial_value = REAL(j, T) * REAL(m, T) + REAL(i, T)
        END FUNCTION
      END PROGRAM