diff --git a/src/common/cuda/kernel_commons.h b/src/common/cuda/kernel_commons.h index e2ef107..32bb7b5 100644 --- a/src/common/cuda/kernel_commons.h +++ b/src/common/cuda/kernel_commons.h @@ -33,11 +33,7 @@ using cuda_bfloat162 = __mt_bfloat162; namespace infini::ops { -constexpr int CUDA_BLOCK_SIZE_128 = 128; -constexpr int CUDA_BLOCK_SIZE_256 = 256; -constexpr int CUDA_BLOCK_SIZE_512 = 512; -constexpr int CUDA_BLOCK_SIZE_1024 = 1024; -constexpr int CUDA_BLOCK_SIZE_2048 = 2048; +using AllCudaBlockSizes = List<128, 256, 512, 1024, 2048>; #if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) // Cache `cudaDeviceProp` per device, initialized once at first access. @@ -76,7 +72,7 @@ inline int QueryMaxThreadsPerBlock() { #elif defined(WITH_METAX) inline int QueryMaxThreadsPerBlock() { // TODO: Add MCR device properties query for Metax. - return CUDA_BLOCK_SIZE_256; + return 256; } #elif defined(WITH_MOORE) inline int QueryMaxThreadsPerBlock() { @@ -91,16 +87,16 @@ inline int QueryMaxThreadsPerBlock() { // Get optimal block size based on GPU hardware architecture. inline int GetOptimalBlockSize() { int max_threads = QueryMaxThreadsPerBlock(); - if (max_threads >= CUDA_BLOCK_SIZE_2048) { - return CUDA_BLOCK_SIZE_2048; - } else if (max_threads >= CUDA_BLOCK_SIZE_1024) { - return CUDA_BLOCK_SIZE_1024; - } else if (max_threads >= CUDA_BLOCK_SIZE_512) { - return CUDA_BLOCK_SIZE_512; - } else if (max_threads >= CUDA_BLOCK_SIZE_256) { - return CUDA_BLOCK_SIZE_256; + if (max_threads >= 2048) { + return 2048; + } else if (max_threads >= 1024) { + return 1024; + } else if (max_threads >= 512) { + return 512; + } else if (max_threads >= 256) { + return 256; } else { - return CUDA_BLOCK_SIZE_128; + return 128; } } diff --git a/src/cuda/add/kernel.h b/src/cuda/add/kernel.h index c174afb..912c0b8 100644 --- a/src/cuda/add/kernel.h +++ b/src/cuda/add/kernel.h @@ -50,10 +50,12 @@ class CudaAdd : public Add { void operator()(const Tensor input, const Tensor other, Tensor out) const override { int block_size = GetOptimalBlockSize(); - DispatchFunc( - out_type_, - [&](auto tag) { - using T = typename decltype(tag)::type; + DispatchFunc( + {static_cast(out_type_), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + auto cuda_stream = static_cast(stream_ ? stream_ : 0); dim3 blockDims( @@ -64,25 +66,11 @@ class CudaAdd : public Add { const T* d_input = reinterpret_cast(input.data()); const T* d_other = reinterpret_cast(other.data()); -#define LAUNCH_ADD_KERNEL(BLOCK_SIZE) \ - AddKernel<<>>( \ - d_out, d_input, d_other, d_out_shape_, d_input_shape_, d_other_shape_, \ - d_out_strides_, d_input_strides_, d_other_strides_, output_size_, ndim_, \ - is_out_contiguous_, is_input_contiguous_, is_other_contiguous_); - - if (block_size == CUDA_BLOCK_SIZE_2048) { - LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_2048) - } else if (block_size == CUDA_BLOCK_SIZE_1024) { - LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_1024) - } else if (block_size == CUDA_BLOCK_SIZE_512) { - LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_512) - } else if (block_size == CUDA_BLOCK_SIZE_256) { - LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_256) - } else { - LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_128) - } - -#undef LAUNCH_ADD_KERNEL + AddKernel<<>>( + d_out, d_input, d_other, d_out_shape_, d_input_shape_, + d_other_shape_, d_out_strides_, d_input_strides_, + d_other_strides_, output_size_, ndim_, is_out_contiguous_, + is_input_contiguous_, is_other_contiguous_); }, "CudaAdd::operator()"); } diff --git a/src/cuda/causal_softmax/kernel.h b/src/cuda/causal_softmax/kernel.h index 924be40..a54f493 100644 --- a/src/cuda/causal_softmax/kernel.h +++ b/src/cuda/causal_softmax/kernel.h @@ -34,32 +34,20 @@ class CudaCausalSoftmax : public CausalSoftmax { int block_size = GetOptimalBlockSize(); - DispatchFunc( - out.dtype(), - [&](auto tag) { - using T = typename decltype(tag)::type; - -#define LAUNCH_CAUSAL_SOFTMAX_KERNEL(BLOCK_SIZE) \ - CausalSoftmaxKernel \ - <<>>( \ - reinterpret_cast(out.data()), \ - reinterpret_cast(input.data()), batch_size_, seq_len_, \ - total_seq_len_, stride_out_batch, stride_out_row, \ - stride_input_batch, stride_input_row); - - if (block_size == CUDA_BLOCK_SIZE_2048) { - LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_2048) - } else if (block_size == CUDA_BLOCK_SIZE_1024) { - LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_1024) - } else if (block_size == CUDA_BLOCK_SIZE_512) { - LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_512) - } else if (block_size == CUDA_BLOCK_SIZE_256) { - LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_256) - } else { - LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_128) - } - -#undef LAUNCH_CAUSAL_SOFTMAX_KERNEL + DispatchFunc, ReducedFloatTypes>, + AllCudaBlockSizes>( + // TODO: output dtype should use the one passed in during construction. + {static_cast(out.dtype()), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + + CausalSoftmaxKernel + <<>>( + reinterpret_cast(out.data()), + reinterpret_cast(input.data()), batch_size_, + seq_len_, total_seq_len_, stride_out_batch, stride_out_row, + stride_input_batch, stride_input_row); }, "CudaCausalSoftmax::operator()"); } diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h index dc28ee5..8a75d75 100644 --- a/src/cuda/rms_norm/kernel.h +++ b/src/cuda/rms_norm/kernel.h @@ -36,32 +36,20 @@ class CudaRmsNorm : public RmsNorm { int block_size = GetOptimalBlockSize(); - DispatchFunc( - out.dtype(), - [&](auto tag) { - using T = typename decltype(tag)::type; - -#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \ - RmsNormKernel \ - <<>>( \ - reinterpret_cast(out.data()), stride_out_batch, \ - stride_out_nhead, reinterpret_cast(input.data()), \ - stride_input_batch, stride_input_nhead, \ - reinterpret_cast(weight.data()), nhead_, dim_, eps); - - if (block_size == CUDA_BLOCK_SIZE_2048) { - LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_2048) - } else if (block_size == CUDA_BLOCK_SIZE_1024) { - LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_1024) - } else if (block_size == CUDA_BLOCK_SIZE_512) { - LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_512) - } else if (block_size == CUDA_BLOCK_SIZE_256) { - LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_256) - } else { - LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_128) - } - -#undef LAUNCH_RMS_NORM_KERNEL + DispatchFunc, ReducedFloatTypes>, + AllCudaBlockSizes>( + {static_cast(out.dtype()), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + + RmsNormKernel + <<>>( + reinterpret_cast(out.data()), stride_out_batch, + stride_out_nhead, reinterpret_cast(input.data()), + stride_input_batch, stride_input_nhead, + reinterpret_cast(weight.data()), nhead_, dim_, + eps_); }, "CudaRmsNorm::operator()"); } diff --git a/src/cuda/swiglu/kernel.h b/src/cuda/swiglu/kernel.h index 47849fe..964e8b7 100644 --- a/src/cuda/swiglu/kernel.h +++ b/src/cuda/swiglu/kernel.h @@ -50,10 +50,12 @@ class CudaSwiglu : public Swiglu { void operator()(const Tensor input, const Tensor gate, Tensor out) const override { int block_size = GetOptimalBlockSize(); - DispatchFunc( - out_type_, - [&](auto tag) { - using T = typename decltype(tag)::type; + DispatchFunc( + {static_cast(out_type_), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + auto cuda_stream = static_cast(stream_ ? stream_ : 0); dim3 blockDims( @@ -64,25 +66,11 @@ class CudaSwiglu : public Swiglu { const T* d_input = reinterpret_cast(input.data()); const T* d_gate = reinterpret_cast(gate.data()); -// Launch kernel with appropriate block size based on GPU architecture. -#define LAUNCH_SWIGLU_KERNEL(BLOCK_SIZE) \ - SwigluKernel<<>>( \ - d_out, d_input, d_gate, d_out_shape_, d_input_shape_, d_gate_shape_, \ - d_out_strides_, d_input_strides_, d_gate_strides_, output_size_, ndim_, \ - is_out_contiguous_, is_input_contiguous_, is_gate_contiguous_); - if (block_size == CUDA_BLOCK_SIZE_2048) { - LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_2048) - } else if (block_size == CUDA_BLOCK_SIZE_1024) { - LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_1024) - } else if (block_size == CUDA_BLOCK_SIZE_512) { - LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_512) - } else if (block_size == CUDA_BLOCK_SIZE_256) { - LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_256) - } else { - LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_128) - } - -#undef LAUNCH_SWIGLU_KERNEL + SwigluKernel<<>>( + d_out, d_input, d_gate, d_out_shape_, d_input_shape_, + d_gate_shape_, d_out_strides_, d_input_strides_, d_gate_strides_, + output_size_, ndim_, is_out_contiguous_, is_input_contiguous_, + is_gate_contiguous_); }, "CudaSwiglu::operator()"); } diff --git a/src/dispatcher.h b/src/dispatcher.h index 83b282c..4e7314c 100644 --- a/src/dispatcher.h +++ b/src/dispatcher.h @@ -302,6 +302,16 @@ auto DispatchFunc(ValueType value, Functor &&func, std::forward(args)...); } +// Interface for Any `int64_t`-convertible Types +template +auto DispatchFunc(std::initializer_list keys, Functor &&func, + std::string_view context_str = "", Args &&...args) { + std::vector v_keys(keys); + return DispatchFunc(v_keys, 0, std::forward(func), + context_str, List<>{}, + std::forward(args)...); +} + } // namespace infini::ops #endif diff --git a/src/operator.h b/src/operator.h index be6fb51..1378988 100644 --- a/src/operator.h +++ b/src/operator.h @@ -103,10 +103,10 @@ class Operator : public OperatorBase { DispatchFunc( tensor.device().type(), [&](auto tag) { - constexpr Device::Type dev = decltype(tag)::value; - if constexpr (std::is_constructible_v, + constexpr Device::Type kDev = decltype(tag)::value; + if constexpr (std::is_constructible_v, const Tensor&, Args...>) { - op_ptr = std::make_unique>( + op_ptr = std::make_unique>( tensor, std::forward(args)...); } else { assert(false && "operator is not implemented for this device");