microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
wechi/ort_test

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

cmake/ext_cuda.cmake

40lines · modecode

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

# CUDA toolchain setup and attention-kernel feature options.
#
# The toolkit is REQUIRED: everything below assumes a usable CUDA
# installation, so fail here with a clear message rather than later with an
# obscure one.
find_package(CUDAToolkit REQUIRED)
enable_language(CUDA)

# Use the shared CUDA runtime and compile CUDA sources as C++17.
set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
set(CMAKE_CUDA_STANDARD 17)

include(CMakeDependentOption)

# USE_FLASH_ATTENTION: user-selectable (default ON) when not on Windows;
# forced OFF on WIN32.
cmake_dependent_option(USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32" OFF)

# USE_MEMORY_EFFICIENT_ATTENTION: user-selectable on every platform, default ON.
option(USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)

# Both attention kernels are switched off for CUDA compilers older than 11.6.
# The plain set() calls override the option values for the rest of this scope.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
  message( STATUS "Turn off flash attention and memory efficient attention since CUDA compiler version < 11.6")
  set(USE_FLASH_ATTENTION OFF)
  set(USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()

# Global nvcc flags for every CUDA source configured after this point.
string(APPEND CMAKE_CUDA_FLAGS " --expt-relaxed-constexpr")

# From CUDA 11 on, treat launches on the default stream as errors.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11)
  string(APPEND CMAKE_CUDA_FLAGS " --Werror default-stream-launch")
endif()

if(NOT WIN32)
  # Forward -fPIC to the host compiler. CUDA_NVCC_FLAGS belongs to the legacy
  # FindCUDA module and is not consulted by the enable_language(CUDA)
  # toolchain, so the flag must also go into CMAKE_CUDA_FLAGS to take effect.
  string(APPEND CMAKE_CUDA_FLAGS " --compiler-options -fPIC")
  # Kept for backward compatibility with anything still reading the legacy list.
  list(APPEND CUDA_NVCC_FLAGS --compiler-options -fPIC)
endif()

# Options passed to cudafe: suppress these front-end diagnostics.
string(APPEND CMAKE_CUDA_FLAGS " -Xcudafe \"--diag_suppress=bad_friend_decl\"")
string(APPEND CMAKE_CUDA_FLAGS " -Xcudafe \"--diag_suppress=unsigned_compare_with_zero\"")
string(APPEND CMAKE_CUDA_FLAGS " -Xcudafe \"--diag_suppress=expr_has_no_effect\"")

# Preprocessor definitions that expose the CUDA configuration to the sources.
# USE_CUDA is always defined; each attention kernel adds its own macro only
# when the corresponding option is enabled.
add_compile_definitions(USE_CUDA)

if(USE_FLASH_ATTENTION)
  add_compile_definitions(USE_FLASH_ATTENTION)
  message( STATUS "Enable flash attention")
endif()

if(USE_MEMORY_EFFICIENT_ATTENTION)
  add_compile_definitions(USE_MEMORY_EFFICIENT_ATTENTION)
  message( STATUS "Enable memory efficient attention")
endif()
