File: onnxruntime_compile_triton_kernel.cmake

package info (click to toggle)
onnxruntime 1.21.0+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 333,732 kB
  • sloc: cpp: 3,153,079; python: 179,219; ansic: 109,131; asm: 37,791; cs: 34,424; perl: 13,070; java: 11,047; javascript: 6,330; pascal: 4,126; sh: 3,277; xml: 598; objc: 281; makefile: 59
file content (35 lines) | stat: -rw-r--r-- 1,455 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

# The Triton compiler driver is a Python script, so an interpreter is mandatory.
find_package(Python3 REQUIRED COMPONENTS Interpreter)

# Triton kernel sources to compile, given as paths relative to the repository
# root. The list is only populated for ROCm builds.
if(onnxruntime_USE_ROCM)
  set(
    triton_kernel_scripts
    "onnxruntime/core/providers/rocm/math/softmax_triton.py"
    "onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py"
  )
endif()

# Registers the build rule that compiles every script listed in
# `triton_kernel_scripts` into one archive plus a companion header.
#
# Arguments:
#   out_triton_kernel_obj_file   - name of the caller's variable that receives
#                                  the path of the generated
#                                  triton_kernel_infos.a
#   out_triton_kernel_header_dir - name of the caller's variable that receives
#                                  the directory containing
#                                  triton_kernel_infos.h
#
# Reads `REPO_ROOT` and `triton_kernel_scripts` (paths relative to REPO_ROOT)
# from the calling scope, uses the `Python3::Interpreter` imported target, and
# defines the custom target `onnxruntime_triton_kernel`.
#
# Example:
#   compile_triton_kernel(triton_kernel_obj_file triton_kernel_header_dir)
function(compile_triton_kernel out_triton_kernel_obj_file out_triton_kernel_header_dir)
  # compile triton kernel, generate .a and .h files
  set(triton_kernel_compiler "${REPO_ROOT}/tools/ci_build/compile_triton.py")
  set(out_dir "${CMAKE_CURRENT_BINARY_DIR}/triton_kernels")
  set(out_obj_file "${out_dir}/triton_kernel_infos.a")
  set(header_file "${out_dir}/triton_kernel_infos.h")

  # Turn the repo-relative script paths into absolute ones. This only changes
  # the function-local copy of the list; the caller's variable is untouched.
  list(TRANSFORM triton_kernel_scripts PREPEND "${REPO_ROOT}/")

  add_custom_command(
    OUTPUT "${out_obj_file}" "${header_file}"
    # Make sure the output directory exists before anything is written into it
    # (a no-op if it is already there).
    COMMAND "${CMAKE_COMMAND}" -E make_directory "${out_dir}"
    # `triton_kernel_scripts` is intentionally left unquoted so that each
    # script becomes its own command-line argument / dependency.
    COMMAND Python3::Interpreter "${triton_kernel_compiler}"
            --header "${header_file}"
            --script_files ${triton_kernel_scripts}
            --obj_file "${out_obj_file}"
    DEPENDS ${triton_kernel_scripts} "${triton_kernel_compiler}"
    COMMENT "Triton compile generates: ${out_obj_file}"
    # Escape arguments identically on every platform/generator.
    VERBATIM
  )
  add_custom_target(onnxruntime_triton_kernel DEPENDS "${out_obj_file}" "${header_file}")

  # Hand the results back through the caller-named output variables.
  set(${out_triton_kernel_obj_file} "${out_obj_file}" PARENT_SCOPE)
  set(${out_triton_kernel_header_dir} "${out_dir}" PARENT_SCOPE)
endfunction()