diff --git a/include/infinicore/adaptor/aten_adaptor.hpp b/include/infinicore/adaptor/aten_adaptor.hpp index 70cb98e18..386dd3ee4 100644 --- a/include/infinicore/adaptor/aten_adaptor.hpp +++ b/include/infinicore/adaptor/aten_adaptor.hpp @@ -5,9 +5,12 @@ #include -#ifdef ENABLE_NVIDIA_API +#if defined(ENABLE_NVIDIA_API) #include #include +#elif defined(ENABLE_HYGON_API) +#include +#include #endif namespace infinicore::adaptor { @@ -29,7 +32,8 @@ inline at::ScalarType to_at_dtype(DataType dtype) { } inline at::Device to_at_device(const Device &device) { - if (device.getType() == Device::Type::NVIDIA) { + if (device.getType() == Device::Type::NVIDIA + || device.getType() == Device::Type::HYGON) { return at::Device(at::kCUDA, device.getIndex()); } else if (device.getType() == Device::Type::CPU) { return at::Device(at::kCPU); @@ -40,8 +44,14 @@ inline at::Device to_at_device(const Device &device) { at::Tensor to_aten_tensor(const infinicore::Tensor &t); -#ifdef ENABLE_NVIDIA_API -c10::cuda::CUDAStream get_cuda_stream(); +#if defined(ENABLE_HYGON_API) +using TorchStream = c10::hip::HIPStream; +using TorchStreamGuard = c10::hip::HIPStreamGuard; +TorchStream get_cuda_stream(); +#elif defined(ENABLE_NVIDIA_API) +using TorchStream = c10::cuda::CUDAStream; +using TorchStreamGuard = c10::cuda::CUDAStreamGuard; +TorchStream get_cuda_stream(); #endif } // namespace infinicore::adaptor diff --git a/include/infinicore/adaptor/flash_attention_adaptor.hpp b/include/infinicore/adaptor/flash_attention_adaptor.hpp index 8a9e152fd..a3cb43c97 100644 --- a/include/infinicore/adaptor/flash_attention_adaptor.hpp +++ b/include/infinicore/adaptor/flash_attention_adaptor.hpp @@ -110,5 +110,46 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 int num_splits); +#ifdef ENABLE_HYGON_API +// Hygon-specific wrappers resolved via dlsym at runtime. 
+ +std::vector +vllm_mha_varlen_fwd(at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + std::optional &out_, + const at::Tensor &cu_seqlens_q, + const at::Tensor &cu_seqlens_k, + std::optional &seqused_k, + std::optional &leftpad_k_, + std::optional &block_table_, + std::optional &alibi_slopes_, + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_); + +void paged_attention(at::Tensor &out, + at::Tensor &q, + at::Tensor &k_cache, + at::Tensor &v_cache, + double scale, + at::Tensor &block_table, + at::Tensor &context_lens, + std::optional alibi_slopes, + const std::string &kv_cache_dtype, + std::optional q_scale, + std::optional k_scale, + std::optional v_scale, + int max_context_len); +#endif // ENABLE_HYGON_API + } // namespace flash #endif // ENABLE_FLASH_ATTN diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 0b3eb9655..bd5061c67 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -76,6 +76,17 @@ zeros, ) +# Re-attempt flash_attn preload after _infinicore.so (and its DTK/HIP +# dependencies) are loaded. The initial preload() above may have failed +# because flash_attn_2_cuda.so's dependencies were not yet available. +# Now that _infinicore.so is loaded we can call dlopen(path, RTLD_GLOBAL) +# to upgrade the already-loaded library from RTLD_LOCAL to RTLD_GLOBAL, +# making its symbols visible to dlsym(RTLD_DEFAULT). +with contextlib.suppress(Exception): + from ._preload import preload_flash_attn + + preload_flash_attn() + __all__ = [ # Modules. 
"context", diff --git a/python/infinicore/_preload.py b/python/infinicore/_preload.py index fc5ff6560..fdafdddca 100644 --- a/python/infinicore/_preload.py +++ b/python/infinicore/_preload.py @@ -1,5 +1,8 @@ import ctypes +import glob +import importlib.util import os +import sys from typing import Iterable, List @@ -63,12 +66,94 @@ def preload_hpcc() -> None: _try_load(prefixes, lib) +def preload_torch_hip() -> None: + """ + Best-effort preload of torch HIP runtime libs with RTLD_GLOBAL. + + This helps external extensions resolve c10::hip symbols when they are + not recorded as direct DT_NEEDED dependencies. + """ + spec = importlib.util.find_spec("torch") + if spec is None or not spec.origin: + return + torch_dir = os.path.dirname(spec.origin) + torch_libdir = os.path.join(torch_dir, "lib") + if not os.path.isdir(torch_libdir): + return + + libs = [ + "libtorch_global_deps.so", + "libc10.so", + "libc10_hip.so", + "libtorch_cpu.so", + "libtorch.so", + "libtorch_hip.so", + ] + for lib in libs: + full = os.path.join(torch_libdir, lib) + if os.path.exists(full): + try: + ctypes.CDLL(full, mode=ctypes.RTLD_GLOBAL) + except OSError: + # Best-effort preload, continue on errors. + pass + + +def preload_flash_attn() -> None: + """ + Best-effort preload of flash_attn_2_cuda extension with RTLD_GLOBAL. + + InfiniCore hygon wrapper resolves C symbols like `mha_varlen_fwd` from the + flash-attn extension at runtime via dlsym(RTLD_DEFAULT, ...). The symbols + only need to be available when the operator is actually called, not at + library load time. So this preload is a convenience — if it fails, the + symbols will be resolved later when torch + flash_attn are imported by + the application (e.g. InfiniLM). 
+ """ + candidates: List[str] = [] + from_env = os.getenv("FLASH_ATTN_PREBUILT") + if from_env: + if os.path.isfile(from_env): + candidates.append(from_env) + elif os.path.isdir(from_env): + candidates.extend(glob.glob(os.path.join(from_env, "flash_attn_2_cuda*.so"))) + + # Try resolving via Python import metadata. + spec = importlib.util.find_spec("flash_attn_2_cuda") + if spec and spec.origin and os.path.exists(spec.origin): + candidates.append(spec.origin) + + # Fallback: scan python paths for extension module. + for p in sys.path: + if not p: + continue + candidates.extend(glob.glob(os.path.join(p, "flash_attn_2_cuda*.so"))) + + # Common installation locations. + candidates.extend(glob.glob("/usr/local/lib/python*/dist-packages/flash_attn_2_cuda*.so")) + candidates.extend(glob.glob("/root/.infini/lib/flash_attn_2_cuda*.so")) + + seen = set() + for so_path in candidates: + if not so_path or so_path in seen: + continue + seen.add(so_path) + if not os.path.exists(so_path): + continue + try: + ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) + return + except OSError: + continue + + def _should_preload_device(device_type: str) -> bool: """ Check if preload is needed for a specific device type. """ device_env_map = { "METAX": ["HPCC_PATH", "INFINICORE_PRELOAD_HPCC"], # HPCC/METAX + "HYGON": ["DTK_ROOT", "INFINICORE_PRELOAD_TORCH_HIP"], # Add other device types here as needed: # "ASCEND": ["ASCEND_PATH"], # "CAMBRICON": ["NEUWARE_HOME"], @@ -90,6 +175,8 @@ def preload_device(device_type: str) -> None: """ if device_type == "METAX": preload_hpcc() + elif device_type == "HYGON": + preload_torch_hip() # Add other device preload functions here as needed: # elif device_type == "ASCEND": # preload_ascend() @@ -103,9 +190,20 @@ def preload() -> None: This function detects available device types and preloads their runtime libraries if the environment indicates they are needed. """ + # Always try torch HIP preload first (best-effort, no-op if torch/HIP is absent). 
+ try: + preload_torch_hip() + except Exception: + pass + try: + preload_flash_attn() + except Exception: + pass + # Device types that may require preload device_types = [ "METAX", # HPCC/METAX + "HYGON", # Add other device types here as they are implemented: # "ASCEND", # "CAMBRICON", diff --git a/src/infinicore/adaptor/aten_adaptor.cc b/src/infinicore/adaptor/aten_adaptor.cc index 2edbe3f8f..f219eb051 100644 --- a/src/infinicore/adaptor/aten_adaptor.cc +++ b/src/infinicore/adaptor/aten_adaptor.cc @@ -32,8 +32,13 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) { options); } -#ifdef ENABLE_NVIDIA_API -c10::cuda::CUDAStream get_cuda_stream() { +#if defined(ENABLE_HYGON_API) +TorchStream get_cuda_stream() { + return c10::hip::getStreamFromExternal( + hipStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex()); +} +#elif defined(ENABLE_NVIDIA_API) +TorchStream get_cuda_stream() { return c10::cuda::getStreamFromExternal( cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex()); } diff --git a/src/infinicore/adaptor/flash_attn_hygon_wrapper.cc b/src/infinicore/adaptor/flash_attn_hygon_wrapper.cc new file mode 100644 index 000000000..2ee24617b --- /dev/null +++ b/src/infinicore/adaptor/flash_attn_hygon_wrapper.cc @@ -0,0 +1,250 @@ +#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_HYGON_API) && !defined(ENABLE_NVIDIA_API) + +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// Function pointer types for the extern "C" functions exported by the DCU +// flash_attn shared library (built from flash-attention-cutlass-master). +// We resolve these at runtime via dlsym to avoid hard link-time dependency +// on the prebuilt .so (which requires libtorch_python.so). 
+// --------------------------------------------------------------------------- + +using mha_fwd_kvcache_fn_t = std::vector (*)( + at::Tensor &q, + const at::Tensor &kcache, + const at::Tensor &vcache, + c10::optional &k_, + c10::optional &v_, + c10::optional &seqlens_k_, + c10::optional &rotary_cos_, + c10::optional &rotary_sin_, + c10::optional &cache_batch_idx_, + c10::optional &leftpad_k_, + c10::optional &block_table_, + c10::optional &alibi_slopes_, + c10::optional &out_, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + bool is_rotary_interleaved, + int num_splits, + const c10::optional &s_aux_); + +using mha_varlen_fwd_fn_t = std::vector (*)( + at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + c10::optional &out_, + const at::Tensor &cu_seqlens_q, + const at::Tensor &cu_seqlens_k, + c10::optional &seqused_k, + c10::optional &leftpad_k_, + c10::optional &block_table_, + c10::optional &alibi_slopes_, + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + c10::optional q_descale_, + c10::optional k_descale_, + c10::optional v_descale_, + c10::optional gen_, + const c10::optional &s_aux_); + +static void *resolve_symbol(const char *name) { + void *sym = dlsym(RTLD_DEFAULT, name); + if (sym) { + return sym; + } + throw std::runtime_error( + std::string("flash_attn symbol not found: ") + name + + ". Ensure flash_attn_2_cuda is loaded before calling this function " + "(e.g. import torch; import flash_attn_2_cuda)."); +} + +// --------------------------------------------------------------------------- +// Wrappers in the flash:: namespace. 
+// These match the signatures declared in +// include/infinicore/adaptor/flash_attention_adaptor.hpp +// and bridge the namespace gap between InfiniCore and the DCU library. +// --------------------------------------------------------------------------- + +namespace flash { + +std::vector +mha_fwd_kvcache(at::Tensor &q, + const at::Tensor &kcache, + const at::Tensor &vcache, + std::optional &k_, + std::optional &v_, + std::optional &seqlens_k_, + std::optional &rotary_cos_, + std::optional &rotary_sin_, + std::optional &cache_batch_idx_, + std::optional &leftpad_k_, + std::optional &block_table_, + std::optional &alibi_slopes_, + std::optional &out_, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + bool is_rotary_interleaved, + int num_splits) { + static auto fn = reinterpret_cast( + resolve_symbol("mha_fwd_kvcache")); + c10::optional s_aux = c10::nullopt; + return fn( + q, kcache, vcache, + k_, v_, seqlens_k_, + rotary_cos_, rotary_sin_, cache_batch_idx_, leftpad_k_, + block_table_, alibi_slopes_, out_, + softmax_scale, is_causal, + window_size_left, window_size_right, + softcap, is_rotary_interleaved, num_splits, s_aux); +} + +std::vector +mha_varlen_fwd(at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + std::optional &out_, + const at::Tensor &cu_seqlens_q, + const at::Tensor &cu_seqlens_k, + std::optional &seqused_k, + std::optional &leftpad_k_, + std::optional &block_table_, + std::optional &alibi_slopes_, + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_) { + static auto fn = reinterpret_cast( + resolve_symbol("mha_varlen_fwd")); + c10::optional q_descale = c10::nullopt; + c10::optional k_descale = c10::nullopt; + c10::optional v_descale = c10::nullopt; + 
c10::optional s_aux = c10::nullopt; + return fn( + q, k, v, out_, + cu_seqlens_q, cu_seqlens_k, + seqused_k, leftpad_k_, block_table_, alibi_slopes_, + max_seqlen_q, max_seqlen_k, + p_dropout, softmax_scale, zero_tensors, is_causal, + window_size_left, window_size_right, + softcap, return_softmax, + q_descale, k_descale, v_descale, gen_, s_aux); +} + +// --------------------------------------------------------------------------- +// vllm_mha_varlen_fwd — same signature as mha_varlen_fwd but resolves a +// different symbol. Used for prefill with paged KV cache on Hygon. +// --------------------------------------------------------------------------- + +std::vector +vllm_mha_varlen_fwd(at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + std::optional &out_, + const at::Tensor &cu_seqlens_q, + const at::Tensor &cu_seqlens_k, + std::optional &seqused_k, + std::optional &leftpad_k_, + std::optional &block_table_, + std::optional &alibi_slopes_, + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_) { + static auto fn = reinterpret_cast( + resolve_symbol("vllm_mha_varlen_fwd")); + c10::optional q_descale = c10::nullopt; + c10::optional k_descale = c10::nullopt; + c10::optional v_descale = c10::nullopt; + c10::optional s_aux = c10::nullopt; + return fn( + q, k, v, out_, + cu_seqlens_q, cu_seqlens_k, + seqused_k, leftpad_k_, block_table_, alibi_slopes_, + max_seqlen_q, max_seqlen_k, + p_dropout, softmax_scale, zero_tensors, is_causal, + window_size_left, window_size_right, + softcap, return_softmax, + q_descale, k_descale, v_descale, gen_, s_aux); +} + +// --------------------------------------------------------------------------- +// paged_attention — decode-only paged attention for Hygon. 
+// Signature: (out, q, k_cache, v_cache, scale, block_table, context_lens, +// alibi_slopes, kv_cache_dtype, q_scale, k_scale, v_scale, +// max_context_len, s_aux) -> void +// --------------------------------------------------------------------------- + +using paged_attention_fn_t = void (*)( + at::Tensor &out, + at::Tensor &q, + at::Tensor &k_cache, + at::Tensor &v_cache, + double scale, + at::Tensor &block_table, + at::Tensor &context_lens, + const c10::optional &alibi_slopes, + const std::string &kv_cache_dtype, + const c10::optional &q_scale, + const c10::optional &k_scale, + const c10::optional &v_scale, + int max_context_len, + const c10::optional &s_aux); + +void paged_attention(at::Tensor &out, + at::Tensor &q, + at::Tensor &k_cache, + at::Tensor &v_cache, + double scale, + at::Tensor &block_table, + at::Tensor &context_lens, + std::optional alibi_slopes, + const std::string &kv_cache_dtype, + std::optional q_scale, + std::optional k_scale, + std::optional v_scale, + int max_context_len) { + static auto fn = reinterpret_cast( + resolve_symbol("paged_attention")); + c10::optional s_aux = c10::nullopt; + fn(out, q, k_cache, v_cache, scale, block_table, context_lens, + alibi_slopes, kv_cache_dtype, q_scale, k_scale, v_scale, + max_context_len, s_aux); +} + +} // namespace flash + +#endif // ENABLE_FLASH_ATTN && ENABLE_HYGON_API && !ENABLE_NVIDIA_API diff --git a/src/infinicore/graph/graph.cc b/src/infinicore/graph/graph.cc index 3b8fc57e5..cfd0e5839 100644 --- a/src/infinicore/graph/graph.cc +++ b/src/infinicore/graph/graph.cc @@ -37,7 +37,7 @@ struct Graph::DeviceGraph { infinirtGraphNode_t node; std::vector log_buffer; - DeviceGraph() { + DeviceGraph() : graph(nullptr), exec(nullptr), node(nullptr) { log_buffer.resize(4 * 1024); } @@ -111,6 +111,7 @@ void Graph::instantiate() { warned_once = true; spdlog::warn("Fail to instantiate device graph: {}", std::string(device_graph_.get()->log_buffer.data())); } + device_graph_.reset(); } } diff --git 
a/src/infinicore/nn/embedding.cc b/src/infinicore/nn/embedding.cc index fde40c761..ff4ecb398 100644 --- a/src/infinicore/nn/embedding.cc +++ b/src/infinicore/nn/embedding.cc @@ -45,7 +45,7 @@ Embedding::Embedding(size_t num_embeddings, Tensor Embedding::forward(const Tensor &indices) const { // TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach auto device_type = device_.getType(); - if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI) { + if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI || device_type == Device::Type::HYGON) { // Use op::embedding which supports device-side input and batch dimension return op::embedding(indices->contiguous()->to(device_), weight_); } diff --git a/src/infinicore/nn/rmsnorm.cc b/src/infinicore/nn/rmsnorm.cc index bc703300f..9886b6223 100644 --- a/src/infinicore/nn/rmsnorm.cc +++ b/src/infinicore/nn/rmsnorm.cc @@ -31,7 +31,8 @@ void RMSNorm::forward_inplace(Tensor &x, Tensor &residual) const { || device_.getType() == Device::Type::ILUVATAR || device_.getType() == Device::Type::METAX || device_.getType() == Device::Type::MOORE - || device_.getType() == Device::Type::ALI) { + || device_.getType() == Device::Type::ALI + || device_.getType() == Device::Type::HYGON) { op::add_rms_norm_inplace(x, residual, weight_, static_cast(eps_)); } else { op::add_(residual, x, residual); diff --git a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc index 24fcf7aea..ce06bd21f 100644 --- a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc +++ b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc @@ -33,13 +33,13 @@ void *plan(Tensor out, void 
run(void *planned_meta) { #ifdef ENABLE_FLASH_ATTN - c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); + infinicore::adaptor::TorchStreamGuard guard(infinicore::adaptor::get_cuda_stream()); auto *p = reinterpret_cast(planned_meta); auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out); auto q = infinicore::adaptor::to_aten_tensor(p->q); - auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache); - auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache); + auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache).contiguous(); + auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache).contiguous(); auto seqlens_k = std::optional(infinicore::adaptor::to_aten_tensor(p->seqlens_k)); auto block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table)); auto alibi_slopes = p->alibi_slopes diff --git a/src/infinicore/ops/mha_kvcache/mha_kvcache_hygon_paged.cc b/src/infinicore/ops/mha_kvcache/mha_kvcache_hygon_paged.cc new file mode 100644 index 000000000..8f5d82240 --- /dev/null +++ b/src/infinicore/ops/mha_kvcache/mha_kvcache_hygon_paged.cc @@ -0,0 +1,122 @@ +// Hygon DCU decode attention backend using flash::paged_attention. +// +// Replaces mha_fwd_kvcache (which calls torch::empty for softmax_lse, +// breaking HIP graph capture) with paged_attention (static allocations only). +// +// paged_attention requires V in transposed layout [num_blocks, nkv, hd, block_size]. +// We pre-allocate the transposed buffer in plan() and do an ATen copy in run() +// from the standard V cache [num_blocks, nkv, block_size, hd]. 
+ +#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_HYGON_API) && !defined(ENABLE_NVIDIA_API) + +#include "infinicore/ops/mha_kvcache.hpp" + +#include "infinicore/adaptor/flash_attention_adaptor.hpp" + +#include + +namespace infinicore::op::mha_kvcache_impl::hygon_paged { + +struct PlannedMeta { + graph::GraphTensor out, q, k_cache, v_cache, seqlens_k, block_table; + std::optional alibi_slopes; + float scale; + // Pre-allocated buffer for transposed V: [num_blocks, nkv, hd, block_size] + graph::GraphTensor v_transposed_buf; +}; + +void *plan(Tensor out, + const Tensor &q, + const Tensor &k_cache, + const Tensor &v_cache, + const Tensor &seqlens_k, + const Tensor &block_table, + std::optional alibi_slopes, + float scale) { + // v_cache arrives permuted: [num_blocks, block_size, nkv, hd] + // paged_attention needs V transposed: [num_blocks, nkv, hd, block_size] + auto vs = v_cache->shape(); // {num_blocks, block_size, nkv, hd} + auto v_transposed = Tensor::empty( + {vs[0], vs[2], vs[3], vs[1]}, // {num_blocks, nkv, hd, block_size} + v_cache->dtype(), + v_cache->device()); + + return new PlannedMeta{ + graph::GraphTensor(out), + graph::GraphTensor(q), + graph::GraphTensor(k_cache), + graph::GraphTensor(v_cache), + graph::GraphTensor(seqlens_k), + graph::GraphTensor(block_table), + alibi_slopes ? std::optional(graph::GraphTensor(*alibi_slopes)) : std::nullopt, + scale, + graph::GraphTensor(v_transposed)}; +} + +void run(void *planned_meta) { + infinicore::adaptor::TorchStreamGuard guard(infinicore::adaptor::get_cuda_stream()); + auto *p = reinterpret_cast(planned_meta); + + auto out_t = infinicore::adaptor::to_aten_tensor(p->out); + auto q_t = infinicore::adaptor::to_aten_tensor(p->q); + // InfiniLM passes K/V in permuted layout [num_blocks, block_size, nkv, hd]. + // paged_attention expects K: [num_blocks, nkv, block_size, hd] + // Permute back; double-permute restores original contiguous strides (no-op copy). 
+ auto k_t = infinicore::adaptor::to_aten_tensor(p->k_cache).permute({0, 2, 1, 3}).contiguous(); + auto v_raw = infinicore::adaptor::to_aten_tensor(p->v_cache); + auto v_transposed = infinicore::adaptor::to_aten_tensor(p->v_transposed_buf); + auto seqlens = infinicore::adaptor::to_aten_tensor(p->seqlens_k); + auto block_tbl = infinicore::adaptor::to_aten_tensor(p->block_table); + + // Transpose V on GPU: + // v_raw: [num_blocks, block_size, nkv, hd] (permuted from InfiniLM) + // target: [num_blocks, nkv, hd, block_size] + v_transposed.copy_(v_raw.permute({0, 2, 3, 1})); + + // Both q and out are 4D: [num_seqs, seqlen, num_heads, head_dim] + // paged_attention accesses query.size(2) for num_heads and query.size(3) for head_size, + // so query must be 4D (despite misleading comment in the source). + + // Compute a safe upper bound for max_context_len from tensor shapes. + // Using seqlens.max().item() would require a D2H transfer that breaks + // HIP graph capture. paged_attention uses per-sequence seqlens internally, + // so a larger bound only wastes a few grid blocks. 
+ int block_size = static_cast(k_t.size(2)); // after permute-back: [num_blocks, nkv, block_size, hd] + int max_blocks_per_seq = static_cast(block_tbl.size(1)); + int max_context_len = max_blocks_per_seq * block_size; + + std::optional alibi = std::nullopt; + std::optional q_scale = std::nullopt; + std::optional k_scale = std::nullopt; + std::optional v_scale = std::nullopt; + + flash::paged_attention( + out_t, // out: [num_seqs, 1, num_heads, head_dim] (4D) + q_t, // query: [num_seqs, 1, num_heads, head_dim] (4D) + k_t, // [num_blocks, nkv, block_size, hd] + v_transposed, // [num_blocks, nkv, hd, block_size] + p->scale, + block_tbl, + seqlens, + alibi, + std::string("auto"), + q_scale, k_scale, v_scale, + max_context_len); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +// Register for Hygon device only, overriding the ALLDEVICE flashattn registration. +static bool registered = []() { + MhaKVCache::plan_dispatcher().registerDevice(Device::Type::HYGON, &plan, true); + MhaKVCache::run_dispatcher().registerDevice(Device::Type::HYGON, &run, true); + MhaKVCache::cleanup_dispatcher().registerDevice(Device::Type::HYGON, &cleanup, true); + return true; +}(); + +} // namespace infinicore::op::mha_kvcache_impl::hygon_paged + +#endif // ENABLE_FLASH_ATTN && ENABLE_HYGON_API && !ENABLE_NVIDIA_API diff --git a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc index aff085898..514f33009 100644 --- a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc +++ b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc @@ -41,18 +41,25 @@ void *plan(Tensor out, void run(void *planned_meta) { #ifdef ENABLE_FLASH_ATTN - c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); + infinicore::adaptor::TorchStreamGuard guard(infinicore::adaptor::get_cuda_stream()); auto 
*p = reinterpret_cast(planned_meta); auto q = infinicore::adaptor::to_aten_tensor(p->q); - auto k = infinicore::adaptor::to_aten_tensor(p->k); - auto v = infinicore::adaptor::to_aten_tensor(p->v); + auto k = infinicore::adaptor::to_aten_tensor(p->k).contiguous(); + auto v = infinicore::adaptor::to_aten_tensor(p->v).contiguous(); auto out = std::optional(infinicore::adaptor::to_aten_tensor(p->out)); auto cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q); auto cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k); + auto block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table)); + + // Flash-attn requires cu_seqlens and block_table on same device as q/k/v. + auto device = q.device(); + if (!cu_seqlens_q.is_cuda()) cu_seqlens_q = cu_seqlens_q.to(device); + if (!cu_seqlens_kv.is_cuda()) cu_seqlens_kv = cu_seqlens_kv.to(device); + if (block_table.has_value() && !block_table->is_cuda()) block_table = block_table->to(device); + std::optional seqused_k = std::nullopt; std::optional leftpad_k = std::nullopt; - auto block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table)); auto max_seqlen_q = p->max_seqlen_q; auto max_seqlen_k = p->max_seqlen_k; auto alibi_slopes = p->alibi_slopes ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt; diff --git a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_hygon_vllm.cc b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_hygon_vllm.cc new file mode 100644 index 000000000..b3326d443 --- /dev/null +++ b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_hygon_vllm.cc @@ -0,0 +1,102 @@ +// Hygon DCU prefill attention backend using flash::vllm_mha_varlen_fwd. +// +// Replaces the generic mha_varlen_fwd with the vLLM variant which is +// optimized for paged KV-cache prefill on Hygon DCU. 
+ +#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_HYGON_API) && !defined(ENABLE_NVIDIA_API) + +#include "infinicore/ops/mha_varlen.hpp" + +#include "infinicore/adaptor/flash_attention_adaptor.hpp" + +#include + +namespace infinicore::op::mha_varlen_impl::hygon_vllm { + +struct PlannedMeta { + graph::GraphTensor out, q, k, v, cum_seqlens_q, cum_seqlens_k, block_table; + int max_seqlen_q, max_seqlen_k; + std::optional alibi_slopes; + float scale; +}; + +void *plan(Tensor out, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &cum_seqlens_q, + const Tensor &cum_seqlens_k, + const Tensor &block_table, + int max_seqlen_q, + int max_seqlen_k, + std::optional alibi_slopes, + float scale) { + return new PlannedMeta{ + graph::GraphTensor(out), + graph::GraphTensor(q), + graph::GraphTensor(k), + graph::GraphTensor(v), + graph::GraphTensor(cum_seqlens_q), + graph::GraphTensor(cum_seqlens_k), + graph::GraphTensor(block_table), + max_seqlen_q, + max_seqlen_k, + alibi_slopes ? std::optional(graph::GraphTensor(*alibi_slopes)) : std::nullopt, + scale}; +} + +void run(void *planned_meta) { + infinicore::adaptor::TorchStreamGuard guard(infinicore::adaptor::get_cuda_stream()); + auto *p = reinterpret_cast(planned_meta); + + auto q = infinicore::adaptor::to_aten_tensor(p->q); + // InfiniLM passes K/V in permuted layout [num_blocks, block_size, nkv, hd]. + // vllm_mha_varlen_fwd expects: + // K: [num_blocks, nkv, block_size, hd] — standard cache layout + // V: [num_blocks, nkv, hd, block_size] — transposed V layout + // K: permute back {0,2,1,3}; double-permute restores contiguous strides (no-op copy). 
+ auto k = infinicore::adaptor::to_aten_tensor(p->k).permute({0, 2, 1, 3}).contiguous(); + // V: from [num_blocks, block_size, nkv, hd] → [num_blocks, nkv, hd, block_size] + auto v = infinicore::adaptor::to_aten_tensor(p->v).permute({0, 2, 3, 1}).contiguous(); + auto out = std::optional(infinicore::adaptor::to_aten_tensor(p->out)); + auto cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q); + auto cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k); + auto block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table)); + + auto device = q.device(); + if (!cu_seqlens_q.is_cuda()) cu_seqlens_q = cu_seqlens_q.to(device); + if (!cu_seqlens_kv.is_cuda()) cu_seqlens_kv = cu_seqlens_kv.to(device); + if (block_table.has_value() && !block_table->is_cuda()) block_table = block_table->to(device); + + std::optional seqused_k = std::nullopt; + std::optional leftpad_k = std::nullopt; + auto alibi_slopes = p->alibi_slopes + ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) + : std::nullopt; + + flash::vllm_mha_varlen_fwd( + q, k, v, out, + cu_seqlens_q, cu_seqlens_kv, + seqused_k, leftpad_k, block_table, alibi_slopes, + p->max_seqlen_q, p->max_seqlen_k, + 0.0f, p->scale, false, true, + -1, -1, 0.0f, false, + std::nullopt); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +// Register for Hygon device only, overriding the ALLDEVICE flashattn registration. 
+static bool registered = []() { + MultiheadAttentionVarlen::plan_dispatcher().registerDevice(Device::Type::HYGON, &plan, true); + MultiheadAttentionVarlen::run_dispatcher().registerDevice(Device::Type::HYGON, &run, true); + MultiheadAttentionVarlen::cleanup_dispatcher().registerDevice(Device::Type::HYGON, &cleanup, true); + return true; +}(); + +} // namespace infinicore::op::mha_varlen_impl::hygon_vllm + +#endif // ENABLE_FLASH_ATTN && ENABLE_HYGON_API && !ENABLE_NVIDIA_API diff --git a/src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh b/src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh index 305820862..e9b3121aa 100644 --- a/src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh +++ b/src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh @@ -1143,8 +1143,7 @@ __device__ void flashAttentionDecodeCtaPipelinedKernel( // Prefetch the very first token. int buf = 0; - int t_base = 0; - int token_in_block = 0; + (void)0; // t_base, token_in_block removed (unused) int logical_block = 0; { if (tid == 0) { diff --git a/src/infiniop/ops/paged_attention/operator.cc b/src/infiniop/ops/paged_attention/operator.cc index 89b48a473..3c9882964 100644 --- a/src/infiniop/ops/paged_attention/operator.cc +++ b/src/infiniop/ops/paged_attention/operator.cc @@ -2,7 +2,7 @@ #include "../../handle.h" #include "infiniop/ops/paged_attention.h" -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) #include "nvidia/paged_attention_nvidia.cuh" #endif #ifdef ENABLE_MOORE_API @@ -48,6 +48,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedAttentionDescriptor( #endif #ifdef ENABLE_ILUVATAR_API CREATE(INFINI_DEVICE_ILUVATAR, nvidia) +#endif +#ifdef ENABLE_HYGON_API + CREATE(INFINI_DEVICE_HYGON, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -78,6 +81,9 @@ __INFINI_C infiniStatus_t 
infiniopGetPagedAttentionWorkspaceSize( #endif #ifdef ENABLE_ILUVATAR_API GET(INFINI_DEVICE_ILUVATAR, nvidia) +#endif +#ifdef ENABLE_HYGON_API + GET(INFINI_DEVICE_HYGON, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -112,6 +118,9 @@ __INFINI_C infiniStatus_t infiniopPagedAttention( #endif #ifdef ENABLE_ILUVATAR_API CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia) +#endif +#ifdef ENABLE_HYGON_API + CALCULATE(INFINI_DEVICE_HYGON, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -141,6 +150,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedAttentionDescriptor( #endif #ifdef ENABLE_ILUVATAR_API DESTROY(INFINI_DEVICE_ILUVATAR, nvidia) +#endif +#ifdef ENABLE_HYGON_API + DESTROY(INFINI_DEVICE_HYGON, nvidia) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh b/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh index da6e1810c..f1441590c 100644 --- a/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh +++ b/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh @@ -2306,9 +2306,10 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel( } __syncthreads(); -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && !defined(ENABLE_HYGON_API) // WMMA: each warp computes scores for 16 keys (one 16-column slice of the K tile) across all 16 rows. // For kBlockN=64, only the first 4 warps participate in WMMA score computation. + // nvcuda::wmma is NVIDIA-only; HIP/ROCm does not support it. 
namespace wmma = nvcuda::wmma; constexpr int kNSub = kBlockN / 16; if (warp_id < kNSub) { diff --git a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu index 100b5bc43..7b1772ad6 100644 --- a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu +++ b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu @@ -1,4 +1,4 @@ -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) #include #include diff --git a/src/infiniop/ops/paged_attention_prefill/operator.cc b/src/infiniop/ops/paged_attention_prefill/operator.cc index 36804cfff..9aa4a39c1 100644 --- a/src/infiniop/ops/paged_attention_prefill/operator.cc +++ b/src/infiniop/ops/paged_attention_prefill/operator.cc @@ -2,7 +2,7 @@ #include "../../handle.h" #include "infiniop/ops/paged_attention_prefill.h" -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) #include "nvidia/paged_attention_prefill_nvidia.cuh" #endif #ifdef ENABLE_METAX_API @@ -48,6 +48,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor( #ifdef ENABLE_ILUVATAR_API CREATE(INFINI_DEVICE_ILUVATAR, nvidia) #endif +#ifdef ENABLE_HYGON_API + CREATE(INFINI_DEVICE_HYGON, nvidia) +#endif #ifdef ENABLE_MOORE_API CREATE(INFINI_DEVICE_MOORE, moore) #endif @@ -78,6 +81,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize( #ifdef ENABLE_ILUVATAR_API GET(INFINI_DEVICE_ILUVATAR, nvidia) #endif +#ifdef ENABLE_HYGON_API + GET(INFINI_DEVICE_HYGON, nvidia) +#endif #ifdef ENABLE_MOORE_API GET(INFINI_DEVICE_MOORE, moore) #endif @@ -115,6 +121,9 @@ 
__INFINI_C infiniStatus_t infiniopPagedAttentionPrefill( #ifdef ENABLE_ILUVATAR_API CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia) #endif +#ifdef ENABLE_HYGON_API + CALCULATE(INFINI_DEVICE_HYGON, nvidia) +#endif #ifdef ENABLE_MOORE_API CALCULATE(INFINI_DEVICE_MOORE, moore) #endif @@ -144,6 +153,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor( #ifdef ENABLE_ILUVATAR_API DESTROY(INFINI_DEVICE_ILUVATAR, nvidia) #endif +#ifdef ENABLE_HYGON_API + DESTROY(INFINI_DEVICE_HYGON, nvidia) +#endif #ifdef ENABLE_MOORE_API DESTROY(INFINI_DEVICE_MOORE, moore) #endif diff --git a/src/infiniop/ops/paged_caching/operator.cc b/src/infiniop/ops/paged_caching/operator.cc index 99a0a1076..831c6ffcd 100644 --- a/src/infiniop/ops/paged_caching/operator.cc +++ b/src/infiniop/ops/paged_caching/operator.cc @@ -2,7 +2,7 @@ #include "../../handle.h" #include "infiniop/ops/paged_caching.h" -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) #include "nvidia/paged_caching_nvidia.cuh" #endif #ifdef ENABLE_METAX_API @@ -41,6 +41,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedCachingDescriptor( #ifdef ENABLE_ILUVATAR_API CREATE(INFINI_DEVICE_ILUVATAR, nvidia) #endif +#ifdef ENABLE_HYGON_API + CREATE(INFINI_DEVICE_HYGON, nvidia) +#endif #ifdef ENABLE_MOORE_API CREATE(INFINI_DEVICE_MOORE, moore) #endif @@ -71,6 +74,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedCachingWorkspaceSize( #ifdef ENABLE_ILUVATAR_API GET(INFINI_DEVICE_ILUVATAR, nvidia) #endif +#ifdef ENABLE_HYGON_API + GET(INFINI_DEVICE_HYGON, nvidia) +#endif #ifdef ENABLE_MOORE_API GET(INFINI_DEVICE_MOORE, moore) #endif @@ -105,6 +111,9 @@ __INFINI_C infiniStatus_t infiniopPagedCaching( #ifdef ENABLE_ILUVATAR_API CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia) #endif +#ifdef ENABLE_HYGON_API + CALCULATE(INFINI_DEVICE_HYGON, nvidia) +#endif #ifdef ENABLE_MOORE_API 
CALCULATE(INFINI_DEVICE_MOORE, moore) #endif @@ -134,6 +143,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedCachingDescriptor( #ifdef ENABLE_ILUVATAR_API DESTROY(INFINI_DEVICE_ILUVATAR, nvidia) #endif +#ifdef ENABLE_HYGON_API + DESTROY(INFINI_DEVICE_HYGON, nvidia) +#endif #ifdef ENABLE_MOORE_API DESTROY(INFINI_DEVICE_MOORE, moore) #endif diff --git a/test/infinicore/ops/mha_varlen.py b/test/infinicore/ops/mha_varlen.py index 942595782..56b5650e8 100644 --- a/test/infinicore/ops/mha_varlen.py +++ b/test/infinicore/ops/mha_varlen.py @@ -15,13 +15,16 @@ TestCase, ) +# gfx936 (Hygon DCU) paged attention only supports page_block_size=64 +_BLOCK_SIZE = 64 if "--hygon" in sys.argv else 256 + # Test Cases: (num_heads, num_kv_heads, head_size, block_size, [request_batch]) _TEST_CASES_DATA = [ - (1, 1, 128, 256, [(250,), (7,)]), - (4, 4, 128, 256, [(250,), (7,)]), - (1, 1, 128, 256, [(260, 73), (1, 1)]), - (8, 2, 128, 256, [(250,), (7,)]), - (8, 2, 128, 256, [(260, 73), (1, 1)]), + (1, 1, 128, _BLOCK_SIZE, [(250,), (7,)]), + (4, 4, 128, _BLOCK_SIZE, [(250,), (7,)]), + (1, 1, 128, _BLOCK_SIZE, [(260, 73), (1, 1)]), + (8, 2, 128, _BLOCK_SIZE, [(250,), (7,)]), + (8, 2, 128, _BLOCK_SIZE, [(260, 73), (1, 1)]), ] _MAX_SEQUENCE_LENGTH = 8192 diff --git a/xmake.lua b/xmake.lua index 507aa45f8..a6faa5c74 100644 --- a/xmake.lua +++ b/xmake.lua @@ -200,6 +200,8 @@ option_end() if has_config("hygon-dcu") then add_defines("ENABLE_HYGON_API") + -- Required by HIP headers included from torch ATen/hip. + add_defines("__HIP_PLATFORM_AMD__") includes("xmake/hygon.lua") end @@ -240,9 +242,20 @@ option("flash-attn") set_description("Path to flash-attention repo. If not set, flash-attention will not used.") option_end() +option("flash-attn-prebuilt") + set_default("") + set_showmenu(true) + set_description("Path to prebuilt flash_attn .so file or directory containing it. 
Used for Hygon DCU.") +option_end() + if has_config("aten") then add_defines("ENABLE_ATEN") - if get_config("flash-attn") ~= false then + local fa_src = get_config("flash-attn") + local fa_prebuilt = get_config("flash-attn-prebuilt") + if not fa_prebuilt or fa_prebuilt == "" then + fa_prebuilt = os.getenv("FLASH_ATTN_PREBUILT") + end + if (fa_src and fa_src ~= "") or (fa_prebuilt and fa_prebuilt ~= "") then add_defines("ENABLE_FLASH_ATTN") end end @@ -469,14 +482,94 @@ target("infinicore_cpp_api") end end - before_build(function (target) + if has_config("hygon-dcu") then + local cuda_sdk = get_config("cuda") or os.getenv("CUDA_HOME") or os.getenv("CUDA_PATH") + local dtk_root = os.getenv("DTK_ROOT") or "/opt/dtk" + local function normalize_cuda_root(root) + if not root or root == "" or not os.isdir(root) then + return nil + end + if os.isdir(path.join(root, "include")) then + return root + end + local nested = { + path.join(root, "cuda"), + path.join(root, "cuda-12") + } + for _, cand in ipairs(nested) do + if os.isdir(path.join(cand, "include")) then + return cand + end + end + return root + end + + -- Prefer xmake --cuda=... for deterministic SDK include/link paths. + local normalized_cuda_sdk = normalize_cuda_root(cuda_sdk) + if normalized_cuda_sdk then + add_includedirs(path.join(normalized_cuda_sdk, "include")) + add_linkdirs(path.join(normalized_cuda_sdk, "lib64")) + end + + -- Keep DTK fallback paths for environments where only DTK_ROOT is set. 
+ if dtk_root and dtk_root ~= "" and os.isdir(dtk_root) then + add_includedirs(path.join(dtk_root, "include")) + add_includedirs(path.join(dtk_root, "cuda", "include")) + add_linkdirs(path.join(dtk_root, "lib")) + add_linkdirs(path.join(dtk_root, "cuda", "lib64")) + end + end + + on_load(function (target) if has_config("aten") then + -- Hygon DCU: link prebuilt flash_attn BEFORE torch for correct symbol resolution order + if has_config("hygon-dcu") then + local fa_prebuilt = get_config("flash-attn-prebuilt") + if not fa_prebuilt or fa_prebuilt == "" then + fa_prebuilt = os.getenv("FLASH_ATTN_PREBUILT") + end + + local flash_so_dir = nil + local flash_so_name = nil + + if fa_prebuilt and fa_prebuilt ~= "" then + if os.isfile(fa_prebuilt) then + flash_so_dir = path.directory(fa_prebuilt) + flash_so_name = path.filename(fa_prebuilt) + else + flash_so_dir = fa_prebuilt + local files = os.files(path.join(fa_prebuilt, "flash_attn_2_cuda*.so")) + if #files > 0 then + flash_so_name = path.filename(files[1]) + end + end + else + local ok, so_path = pcall(function() + return os.iorunv("python", {"-c", "import flash_attn_2_cuda; print(flash_attn_2_cuda.__file__)"}):trim() + end) + if ok and so_path and so_path ~= "" and os.isfile(so_path) then + flash_so_dir = path.directory(so_path) + flash_so_name = path.filename(so_path) + end + end + + if flash_so_dir and flash_so_name then + target:add("linkdirs", flash_so_dir) + target:add("ldflags", "-Wl,--no-as-needed", {force = true}) + target:add("ldflags", "-l:" .. flash_so_name, {force = true}) + target:add("ldflags", "-Wl,--as-needed", {force = true}) + print("Flash Attention library: " .. path.join(flash_so_dir, flash_so_name)) + end + end + local outdata = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim() local TORCH_DIR = outdata + -- Use sysincludedirs (-isystem) so that torch's bundled pybind11 headers + -- do not shadow the xmake pybind11 package headers. 
target:add( - "includedirs", - path.join(TORCH_DIR, "include"), + "sysincludedirs", + path.join(TORCH_DIR, "include"), path.join(TORCH_DIR, "include/torch/csrc/api/include"), { public = true }) @@ -485,14 +578,40 @@ target("infinicore_cpp_api") path.join(TORCH_DIR, "lib"), { public = true } ) - target:add( - "links", - "torch", - "c10", - "torch_cuda", - "c10_cuda", - { public = true } - ) + local torch_libdir = path.join(TORCH_DIR, "lib") + target:add("rpathdirs", torch_libdir) + target:add("ldflags", "-Wl,--no-as-needed", {force = true}) + local torch_links = {"torch", "c10"} + local function has_torch_lib(name) + return #os.files(path.join(torch_libdir, "lib" .. name .. ".so*")) > 0 + end + if has_torch_lib("torch_cuda") then + table.insert(torch_links, "torch_cuda") + elseif has_torch_lib("torch_hip") then + table.insert(torch_links, "torch_hip") + end + if has_torch_lib("c10_cuda") then + table.insert(torch_links, "c10_cuda") + elseif has_torch_lib("c10_hip") then + table.insert(torch_links, "c10_hip") + end + target:add("links", table.unpack(torch_links), { public = true }) + -- Hard-pin runtime dependency entries to avoid linker dropping HIP torch libs. + target:add("ldflags", "-L" .. torch_libdir, {force = true}) + if has_torch_lib("torch_hip") then + target:add("ldflags", "-l:libtorch_hip.so", {force = true}) + end + if has_torch_lib("c10_hip") then + target:add("ldflags", "-l:libc10_hip.so", {force = true}) + end + if has_torch_lib("torch_cuda") then + target:add("ldflags", "-l:libtorch_cuda.so", {force = true}) + end + if has_torch_lib("c10_cuda") then + target:add("ldflags", "-l:libc10_cuda.so", {force = true}) + end + target:add("ldflags", "-Wl,--as-needed", {force = true}) + print("Torch libraries: " .. 
table.concat(torch_links, ", ")) end end) @@ -515,6 +634,40 @@ target("infinicore_cpp_api") add_installfiles("include/infinicore/(**/*.hpp)",{prefixdir = "include/infinicore"}) add_installfiles("include/infinicore.h", {prefixdir = "include"}) add_installfiles("include/infinicore.hpp", {prefixdir = "include"}) + + after_install(function (target) + if not has_config("hygon-dcu") then return end + local fa_prebuilt = get_config("flash-attn-prebuilt") + if not fa_prebuilt or fa_prebuilt == "" then + fa_prebuilt = os.getenv("FLASH_ATTN_PREBUILT") + end + + local flash_so_path = nil + if fa_prebuilt and fa_prebuilt ~= "" then + if os.isfile(fa_prebuilt) then + flash_so_path = fa_prebuilt + else + local files = os.files(path.join(fa_prebuilt, "flash_attn_2_cuda*.so")) + if #files > 0 then flash_so_path = files[1] end + end + else + local ok, so_path = pcall(function() + return os.iorunv("python", {"-c", "import flash_attn_2_cuda; print(flash_attn_2_cuda.__file__)"}):trim() + end) + if ok and so_path and so_path ~= "" and os.isfile(so_path) then + flash_so_path = so_path + end + end + + if flash_so_path then + local installdir = target:installdir() + local libdir = path.join(installdir, "lib") + os.mkdir(libdir) + os.cp(flash_so_path, libdir) + print("Copied prebuilt flash_attn library to " .. libdir) + end + end) + after_build(function (target) print(YELLOW .. "[Congratulations!] Now you can install the libraries with \"xmake install\"" .. NC) end) target_end() diff --git a/xmake/hygon.lua b/xmake/hygon.lua index 942d33255..da3b1008b 100644 --- a/xmake/hygon.lua +++ b/xmake/hygon.lua @@ -64,14 +64,14 @@ target("infiniop-hygon") -- 添加海光DCU特定的编译标志 -- 检测实际GPU架构,如果未指定则默认使用gfx906 - local hygon_arch = os.getenv("HYGON_ARCH") or "gfx906" + local hygon_arch = os.getenv("HYGON_ARCH") or "gfx936" add_cuflags("-arch=" .. hygon_arch) - print("编译海光DCU架构: " .. hygon_arch) - + print("compile hygon architecture: " .. 
hygon_arch) + -- 复用NVIDIA的CUDA实现,通过HIP兼容层 add_files("../src/infiniop/devices/nvidia/*.cu", "../src/infiniop/ops/*/nvidia/*.cu") - -- temporarily disble paged ops for hygon - remove_files("../src/infiniop/ops/paged*/nvidia/*.cu") + -- paged ops are now built for hygon; uncomment the line below to disable them again if they segfault on gfx936 (needs HIP adaptation) + -- remove_files("../src/infiniop/ops/paged*/nvidia/*.cu") if has_config("ninetoothed") then add_files("../build/ninetoothed/*.c", "../build/ninetoothed/*.cpp", {cxxflags = {"-Wno-return-type"}}) @@ -107,9 +107,9 @@ target("infinirt-hygon") -- 添加海光DCU特定的编译标志 -- 检测实际GPU架构,如果未指定则默认使用gfx906 - local hygon_arch = os.getenv("HYGON_ARCH") or "gfx906" + local hygon_arch = os.getenv("HYGON_ARCH") or "gfx936" add_cuflags("-arch=" .. hygon_arch) - + add_files("../src/infinirt/cuda/*.cu") target_end() @@ -143,7 +143,7 @@ target("infiniccl-hygon") -- 添加海光DCU特定的编译标志 -- 检测实际GPU架构,如果未指定则默认使用gfx906 - local hygon_arch = os.getenv("HYGON_ARCH") or "gfx906" + local hygon_arch = os.getenv("HYGON_ARCH") or "gfx936" add_cuflags("-arch=" .. hygon_arch) -- 使用NCCL (NVIDIA Collective Communications Library)