File: dispatch_udt.f

package info (click to toggle)
libxsmm 1.17-4
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 14,976 kB
  • sloc: ansic: 119,587; cpp: 27,680; fortran: 9,179; sh: 5,765; makefile: 5,040; pascal: 2,312; python: 1,812; f90: 1,773
file content (76 lines) | stat: -rw-r--r-- 3,722 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
!=======================================================================!
! Copyright (c) Intel Corporation - All rights reserved.                !
! This file is part of the LIBXSMM library.                             !
!                                                                       !
! For information on the license, see the LICENSE file.                 !
! Further information: https://github.com/hfp/libxsmm/                  !
! SPDX-License-Identifier: BSD-3-Clause                                 !
!=======================================================================!
! Hans Pabst (Intel Corp.)
!=======================================================================!
      PROGRAM dispatch_udt
        ! Demonstrates the LIBXSMM registry for user-defined values:
        ! two double-precision kernels are dispatched once, stored
        ! together under the key (m, n, k) via libxsmm_xregister, and
        ! looked up again with libxsmm_xdispatch in later rounds.
        ! The first kernel (udt(1)) is then called per batch entry.
        USE, INTRINSIC :: ISO_C_BINDING,  ONLY: C_PTR, C_LOC,           &
     &                                          C_ASSOCIATED,           &
     &                                          C_F_POINTER
        USE :: LIBXSMM, ONLY: LIBXSMM_BLASINT_KIND,                     &
     &                        LIBXSMM_MMFUNCTION => LIBXSMM_DMMFUNCTION,&
     &                        libxsmm_mmdispatch => libxsmm_dmmdispatch,&
     &                        libxsmm_mmcall => libxsmm_dmmcall,        &
     &                        libxsmm_xregister, libxsmm_xdispatch
        IMPLICIT NONE
        INTEGER, PARAMETER :: T = KIND(0D0)
        INTEGER :: batchsize = 1000, i
        INTEGER(LIBXSMM_BLASINT_KIND) :: j, ki, nrepeat = 100
        INTEGER(LIBXSMM_BLASINT_KIND) :: m = 13, n = 5, k = 7
        REAL(T), ALLOCATABLE :: a(:,:,:), b(:,:,:), c(:,:)
        TYPE(LIBXSMM_MMFUNCTION), TARGET  :: xmm(2) ! array of kernels
        TYPE(LIBXSMM_MMFUNCTION), POINTER :: udt(:)
        INTEGER(LIBXSMM_BLASINT_KIND), TARGET :: key(3)
        ! sizes in Bytes of the registry key and of the registered value
        INTEGER :: keysize, valsize
        TYPE(C_PTR) :: ptr

        ALLOCATE(a(m,k,batchsize), b(k,n,batchsize), c(m,n))
        ! initialize input; the "+ 1" keeps the divisors in 1..25 (a)
        ! and 1..75 (b), i.e., no division by zero can occur
        DO i = 1, batchsize
          DO ki = 1, k
            DO j = 1, m
              a(j,ki,i) = REAL(1, T) / REAL(MOD(i+j+ki, 25) + 1, T)
            END DO
          END DO
          ! ki innermost (first index of b): contiguous access
          DO j = 1, n
            DO ki = 1, k
              b(ki,j,i) = REAL(7, T) / REAL(MOD(i+j+ki, 75) + 1, T)
            END DO
          END DO
        END DO
        c(:,:) = REAL(0, T)

        ! Byte counts passed to the registry: STORAGE_SIZE yields bits
        ! per element, whereas a KIND value is not necessarily a byte
        ! count and a kernel handle is not necessarily 8 Bytes wide.
        keysize = SIZE(key) * (STORAGE_SIZE(key) / 8)
        valsize = SIZE(xmm) * (STORAGE_SIZE(xmm) / 8)

        ! repeat inner part to exercise libxsmm_xdispatch
        DO j = 1, nrepeat
          key = (/m, n, k/) ! setup key
          ! query associated value using key
          ptr = libxsmm_xdispatch(C_LOC(key), keysize)

          IF (C_ASSOCIATED(ptr)) THEN ! value was already registered
            ! convert C-ptr to Fortran POINTER
            CALL C_F_POINTER(ptr, udt, (/SIZE(xmm)/))
          ELSE ! no value registered yet
            ! generate and dispatch a series of kernels
            CALL libxsmm_mmdispatch(xmm(1), m, n, k,                    &
     &        alpha=REAL(1, T), beta=REAL(1, T))
            CALL libxsmm_mmdispatch(xmm(2), m, n, k + 2,                &
     &        alpha=REAL(1, T), beta=REAL(1, T))
            ! register an entry that contains all kernels from above;
            ! the result is not checked: should registration fail, the
            ! next lookup presumably misses and this branch runs again
            ptr = libxsmm_xregister(C_LOC(key), keysize,                &
     &        valsize, C_LOC(xmm))
            ! point udt to xmm (below code uses udt to refer to kernels)
            udt => xmm ! alternatively, use C_F_POINTER
          END IF

          ! here we executed libxsmm_xdispatch one time (for this round)
          ! all kernels have been dispatched at once (udt)
          ! NOTE(review): udt(1) is not checked for a successful dispatch
          DO i = 1, batchsize
            CALL libxsmm_mmcall(udt(1), a(:,:,i), b(:,:,i), c)
          END DO
        END DO
        DEALLOCATE(a, b, c)
      END PROGRAM dispatch_udt