From 0e5c2eb65e9c395cdb2ce11b4d360ef191130c10 Mon Sep 17 00:00:00 2001 From: flashzxi Date: Wed, 25 Feb 2026 00:13:11 +0800 Subject: [PATCH 1/5] cache --- 03_nf4_dequant/CMakeLists.txt | 28 ++++++ 03_nf4_dequant/src/main.cpp | 3 + 03_nf4_dequant/src/nf4_dequant.cu | 152 ++++++++++++++++++++++++++++++ 3 files changed, 183 insertions(+) create mode 100644 03_nf4_dequant/CMakeLists.txt create mode 100644 03_nf4_dequant/src/main.cpp create mode 100644 03_nf4_dequant/src/nf4_dequant.cu diff --git a/03_nf4_dequant/CMakeLists.txt b/03_nf4_dequant/CMakeLists.txt new file mode 100644 index 0000000..47f7dd1 --- /dev/null +++ b/03_nf4_dequant/CMakeLists.txt @@ -0,0 +1,28 @@ +cmake_minimum_required(VERSION 3.18) + +project(cuda_demo LANGUAGES CXX CUDA) + +# C++ / CUDA 标准 +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +# CUDA 架构 +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + set(CMAKE_CUDA_ARCHITECTURES 86) +endif() + +add_executable(nf4_dequant + src/nf4_dequant.cu + src/main.cpp +) +target_include_directories(nf4_dequant PRIVATE + ${CMAKE_SOURCE_DIR}/third_party/cutlass/include + ${CMAKE_SOURCE_DIR}/third_party/cutlass/tools/util/include +) + +set_target_properties(nf4_dequant PROPERTIES + CUDA_SEPARABLE_COMPILATION ON +) \ No newline at end of file diff --git a/03_nf4_dequant/src/main.cpp b/03_nf4_dequant/src/main.cpp new file mode 100644 index 0000000..1318c41 --- /dev/null +++ b/03_nf4_dequant/src/main.cpp @@ -0,0 +1,3 @@ +// +// Created by flashzxi on 2/24/26. +// diff --git a/03_nf4_dequant/src/nf4_dequant.cu b/03_nf4_dequant/src/nf4_dequant.cu new file mode 100644 index 0000000..3d97e17 --- /dev/null +++ b/03_nf4_dequant/src/nf4_dequant.cu @@ -0,0 +1,152 @@ +// +// Created by flashzxi on 2/24/26. 
+// +#include +#include +#include "cutlass/core_io.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/numeric_types.h" +#include +#include +#define ENABLE_DOUBLE_BUFFER + +#define INT4(value) (reinterpret_cast(&(value))[0]) +#define FLOAT4(value) (reinterpret_cast(&(value))[0]) +#define HALF2(value) (reinterpret_cast(&(value))[0]) +#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) +#define LDST64BITS(value) (reinterpret_cast(&(value))[0]) +#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) + +struct QuantState { + // header + int num_rows; + int num_cols; + int block_size; + int group_size; + + // data + uint8_t* packed_weights; // 每字节存两个 4-bit 索引 + uint8_t* absmax_q; + __half* absmax2; + __half code2[256]; // 二级码表 + float offset; + + // runtime param + std::string compute_type; + std::string target_gpu; + + int num_elements; + int num_blocks; + int num_groups; + + int packed_weights_len_in_bytes; + int absmax_q_len_in_bytes; + int absmax2_len_in_bytes; + + // 输出位置 + uint8_t *output; + + void calculate_params() { + num_elements = num_rows * num_cols; + group_size = block_size; + num_blocks = (num_elements + block_size - 1) / block_size; + num_groups = (num_blocks + group_size - 1) / group_size; + + packed_weights_len_in_bytes = (num_elements + 1) / 2; + absmax_q_len_in_bytes = num_blocks; + absmax2_len_in_bytes = 2 * num_groups; + } +}; + +// code2 为 256 * f16 +// 每个线程load 2 个,需要128个线程, 故设置一个block 128个线程,每个线程处理N个计算 +// 总计处理128 * N个数据, N 是2的幂 且不小于8 +// 结尾不够需要padding +template +__global__ void dequant_nf4_scale_f16xN_kernel(uint8_t* scale_q, FP_T* code2, FP_T* absmax2, int num_blocks, int group_size, FP_T* output) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int lane_id = threadIdx.x; + + // load code2 + __shared__ float shm_code2[128]; + shm_code2[lane_id] = code2[lane_id]; + + // 一次处理8个数据 + constexpr int loop_times = N / 8; + int g_scale_q_offset_base = blockIdx.x * 128 * N; + + 
// 使用double buffer需要使用shared_memory + // 不使用double buffer 可以直接再寄存器上暂存数据 +#ifdef ENABLE_DOUBLE_BUFFER + auto block = cooperative_groups::this_thread_block(); + + constexpr int STAGES = 2; // 双缓冲 + __shared__ cuda::pipeline_shared_state ps; + auto pipe = cuda::make_pipeline(block, &ps); + + // 一次读取8个数据, 双buffer + __shared__ uint8_t fragment[2][128][8]; + // 取第0块 + int scale_offset = g_scale_q_offset_base + lane_id * 8; + if (scale_offset + 8 <= num_blocks) { + pipe.producer_acquire(); + cuda::memcpy_async(block, fragment[0][lane_id], scale_q + scale_offset, cuda::aligned_size_t<16>(16), pipe); + pipe.producer_commit(); + } else if (scale_offset < num_blocks) { + // 尾块处理:退化为标量copy + for (int i = 0; i < 8 && scale_offset + i < num_blocks; ++i) { + fragment[0][lane_id][i] = scale_q[scale_offset + i]; + } + } + +#pragma unroll + for (int i = 0; i < loop_times; ++i) { + + } +#else + uint8_t fragment[8]; + FP_T cache_res[8]; +#pragma unroll + for (int i = 0; i < loop_times; ++i) { + int scale_offset = g_scale_q_offset_base + i * 128 * 8 + lane_id * 8; + FP_T scale2 = absmax2[i * 128 * 8 + lane_id * 8 / group_size]; + if (scale_offset + 7 < num_blocks) { + LDST64BITS(fragment[0]) = LDST64BITS(*(scale_q + scale_offset)); +#pragma unroll + for (int j = 0; j < 8; ++j) { + cache_res[j] = ((FP_T*) shm_code2)[fragment[i]] * scale2; + } + LDST128BITS(output[scale_offset]) = LDST128BITS(cache_res[0]); + } else if (scale_offset < num_blocks) { + // 不够一组,退化为每个元素load + int remains = num_blocks - scale_offset; + for (int j = 0; j < remains; ++j) { + fragment[j] = (scale_q + scale_offset)[j]; + cache_res[j] = ((FP_T*) shm_code2)[fragment[i]] * scale2; + output[scale_offset + j] = cache_res[j]; + } + } + } +#endif +} + +void nf4_dequant(const QuantState& quant_state) { + // 解码scale + cutlass::HostTensor scale_q(cutlass::layout::PitchLinearCoord(quant_state.num_blocks, 1)); + cutlass::HostTensor code2(cutlass::layout::PitchLinearCoord(256, 1)); + cutlass::HostTensor 
absmax2(cutlass::layout::PitchLinearCoord(quant_state.num_groups, 1)); + + memcpy(scale_q.host_data(), quant_state.absmax_q, quant_state.absmax_q_len_in_bytes); + memcpy(code2.host_data(), quant_state.code2, 256 * 2); + memcpy(absmax2.host_data(), quant_state.absmax2, quant_state.absmax2_len_in_bytes); + + scale_q.sync_device(); + code2.sync_device(); + absmax2.sync_device(); + + constexpr int dequant_scale_per_thread_calc = 8; + dim3 dequant_scale_block_dim(128); + dim3 dequant_scale_grid_dim((quant_state.block_size + 128 * dequant_scale_per_thread_calc - 1) / 128 * dequant_scale_per_thread_calc); + // 解码权重 +} \ No newline at end of file From aa886fc8c40a8ff3639376024978803c368ace40 Mon Sep 17 00:00:00 2001 From: flashzxi Date: Thu, 26 Feb 2026 17:44:34 +0800 Subject: [PATCH 2/5] =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=8D=95=E6=A0=B8bf4=20?= =?UTF-8?q?=E5=8F=8D=E9=87=8F=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 03_nf4_dequant/CMakeLists.txt | 28 -- 03_nf4_dequant/flashzxi/CMakeLists.txt | 36 ++ 03_nf4_dequant/flashzxi/README.md | 0 03_nf4_dequant/flashzxi/include/common.cuh | 127 ++++++ 03_nf4_dequant/flashzxi/include/nf4_dequant.h | 10 + 03_nf4_dequant/flashzxi/include/quant_state.h | 269 ++++++++++++ 03_nf4_dequant/flashzxi/src/main.cu | 45 ++ .../flashzxi/src/nf4_dequant_naive.cu | 156 +++++++ .../flashzxi/src/nf4_dequant_warpN.cu | 387 ++++++++++++++++++ .../test/conf/blocksize128_bf16_T4.ini | 3 + .../test/conf/blocksize128_fp16_T4.ini | 3 + .../test/conf/blocksize64_bf16_T4.ini | 3 + .../test/conf/blocksize64_fp16_T4.ini | 3 + 03_nf4_dequant/flashzxi/test/data/baseline.py | 146 +++++++ 03_nf4_dequant/src/main.cpp | 3 - 03_nf4_dequant/src/nf4_dequant.cu | 152 ------- 16 files changed, 1188 insertions(+), 183 deletions(-) delete mode 100644 03_nf4_dequant/CMakeLists.txt create mode 100644 03_nf4_dequant/flashzxi/CMakeLists.txt create mode 100644 03_nf4_dequant/flashzxi/README.md create mode 100644 
03_nf4_dequant/flashzxi/include/common.cuh create mode 100644 03_nf4_dequant/flashzxi/include/nf4_dequant.h create mode 100644 03_nf4_dequant/flashzxi/include/quant_state.h create mode 100644 03_nf4_dequant/flashzxi/src/main.cu create mode 100644 03_nf4_dequant/flashzxi/src/nf4_dequant_naive.cu create mode 100644 03_nf4_dequant/flashzxi/src/nf4_dequant_warpN.cu create mode 100644 03_nf4_dequant/flashzxi/test/conf/blocksize128_bf16_T4.ini create mode 100644 03_nf4_dequant/flashzxi/test/conf/blocksize128_fp16_T4.ini create mode 100644 03_nf4_dequant/flashzxi/test/conf/blocksize64_bf16_T4.ini create mode 100644 03_nf4_dequant/flashzxi/test/conf/blocksize64_fp16_T4.ini create mode 100644 03_nf4_dequant/flashzxi/test/data/baseline.py delete mode 100644 03_nf4_dequant/src/main.cpp delete mode 100644 03_nf4_dequant/src/nf4_dequant.cu diff --git a/03_nf4_dequant/CMakeLists.txt b/03_nf4_dequant/CMakeLists.txt deleted file mode 100644 index 47f7dd1..0000000 --- a/03_nf4_dequant/CMakeLists.txt +++ /dev/null @@ -1,28 +0,0 @@ -cmake_minimum_required(VERSION 3.18) - -project(cuda_demo LANGUAGES CXX CUDA) - -# C++ / CUDA 标准 -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED ON) - -set(CMAKE_CUDA_STANDARD 17) -set(CMAKE_CUDA_STANDARD_REQUIRED ON) - -# CUDA 架构 -if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - set(CMAKE_CUDA_ARCHITECTURES 86) -endif() - -add_executable(nf4_dequant - src/nf4_dequant.cu - src/main.cpp -) -target_include_directories(nf4_dequant PRIVATE - ${CMAKE_SOURCE_DIR}/third_party/cutlass/include - ${CMAKE_SOURCE_DIR}/third_party/cutlass/tools/util/include -) - -set_target_properties(nf4_dequant PROPERTIES - CUDA_SEPARABLE_COMPILATION ON -) \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/CMakeLists.txt b/03_nf4_dequant/flashzxi/CMakeLists.txt new file mode 100644 index 0000000..bf93ab2 --- /dev/null +++ b/03_nf4_dequant/flashzxi/CMakeLists.txt @@ -0,0 +1,36 @@ +cmake_minimum_required(VERSION 3.18) + +project(nf4_dequant LANGUAGES CXX CUDA) + 
+set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +set(CMAKE_CUDA_ARCHITECTURES native) + +add_executable(nf4_dequant + src/main.cu + src/nf4_dequant_naive.cu + src/nf4_dequant_warpN.cu +) + +target_include_directories(nf4_dequant PRIVATE + ${CMAKE_SOURCE_DIR}/include +) + +# 单 TU/简单工程:关闭 RDC 更利于调试 +set_target_properties(nf4_dequant PROPERTIES + CUDA_SEPARABLE_COMPILATION OFF +) + +target_compile_options(nf4_dequant PRIVATE + $<$,$>:-g -O0> + $<$,$>:-G -g -O0> + + $<$,$>:-O3> + $<$,$>:-O3> + + $<$,$>:-g -O2> + $<$,$>:-lineinfo -g -O2> +) \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/README.md b/03_nf4_dequant/flashzxi/README.md new file mode 100644 index 0000000..e69de29 diff --git a/03_nf4_dequant/flashzxi/include/common.cuh b/03_nf4_dequant/flashzxi/include/common.cuh new file mode 100644 index 0000000..87eb1d2 --- /dev/null +++ b/03_nf4_dequant/flashzxi/include/common.cuh @@ -0,0 +1,127 @@ +// +// Created by core_dump on 2026/2/25. 
+// + +#pragma once + +#include +#include +#include +#include +#include + +__host__ __device__ __forceinline__ +float mix_mul(float fp, __half h) { + return fp * __half2float(h); +} + +__host__ __device__ __forceinline__ +float mix_mul(float fp, __nv_bfloat16 h) { + return fp * __bfloat162float(h); +} + +__host__ __device__ __forceinline__ +float f162float(__half h) { + return __half2float(h); +} + +__host__ __device__ __forceinline__ +float f162float(__nv_bfloat16 h) { + return __bfloat162float(h); +} + + +#define CUDA_CHECK(call) \ +{ \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ \ + << " - " << cudaGetErrorString(err) << "\n"; \ + std::exit(-1); \ + } \ +} + +class Timer { +public: + using clock = std::chrono::high_resolution_clock; + + Timer() : running_(false), elapsed_ms_(0.0) {} + + void tic() { + start_ = clock::now(); + running_ = true; + } + + double toc() { + if (!running_) { + return elapsed_ms_; + } + auto end = clock::now(); + elapsed_ms_ = std::chrono::duration(end - start_).count(); + running_ = false; + return elapsed_ms_; + } + + double elapsed() const { + if (!running_) { + return elapsed_ms_; + } + auto now = clock::now(); + return std::chrono::duration(now - start_).count(); + } + + void reset() { + running_ = false; + elapsed_ms_ = 0.0; + } + +private: + clock::time_point start_; + bool running_; + double elapsed_ms_; +}; + +class Tracer { +public: + Tracer() {} + + void start() { + timer_.reset(); + timer_.tic(); + } + + void stop() { + total_elapsed_ms_ += timer_.toc(); + } + + Tracer& memcpy_accumulate(uint64_t cpy_size_in_byte) { + total_data_cpy_in_bytes_ += cpy_size_in_byte; + return *this; + } + + double bandwidth_bytes_per_s() const { + if (total_elapsed_ms_ <= 0.0) { + return 0.0; + } + return static_cast(total_data_cpy_in_bytes_) * 1000.0 / total_elapsed_ms_; + } + + double bandwidth_gib_per_s() const { + if (total_elapsed_ms_ <= 0.0) { + return 0.0; + } + 
constexpr double kBytesPerGiB = 1024.0 * 1024.0 * 1024.0;
+        return static_cast<double>(total_data_cpy_in_bytes_) * 1000.0 / total_elapsed_ms_ / kBytesPerGiB;
+    }
+
+    void print(std::ostream& os = std::cout) const {
+        os << "elapsed: " << total_elapsed_ms_ << " ms, "
+           << "effective bandwidth: " << bandwidth_gib_per_s() << " GiB/s\n";
+    }
+
+private:
+    Timer timer_;
+
+    uint64_t total_data_cpy_in_bytes_ = 0;
+    double total_elapsed_ms_ = 0.0; // FIX: was uninitialized; stop() accumulates with += so the first read was UB
+};
diff --git a/03_nf4_dequant/flashzxi/include/nf4_dequant.h b/03_nf4_dequant/flashzxi/include/nf4_dequant.h
new file mode 100644
index 0000000..1c8265c
--- /dev/null
+++ b/03_nf4_dequant/flashzxi/include/nf4_dequant.h
@@ -0,0 +1,10 @@
+//
+// Created by core_dump on 2/25/26.
+//
+
+#pragma once
+
+#include "quant_state.h"
+void nf4_dequant_naive(const QuantState& quant_state, __half* output);
+void nf4_dequant_warp8_batch32_two_phase(const QuantState& quant_state, __half* output);
+void nf4_dequant_warp8_batch32_one_phase(const QuantState& quant_state, __half* output);
\ No newline at end of file
diff --git a/03_nf4_dequant/flashzxi/include/quant_state.h b/03_nf4_dequant/flashzxi/include/quant_state.h
new file mode 100644
index 0000000..981cc07
--- /dev/null
+++ b/03_nf4_dequant/flashzxi/include/quant_state.h
@@ -0,0 +1,269 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+struct QuantState {
+    // header
+    int num_rows = 0;
+    int num_cols = 0;
+    int block_size = 0;
+    int group_size = 256; // baseline uses 256
+
+    // data (host)
+    uint8_t* packed_weights = nullptr; // two 4-bit indices per byte
+    uint8_t* absmax_q = nullptr;
+    __half* absmax2 = nullptr;
+    __half code2[256]{};
+    float offset = 0.f;
+
+    // runtime param
+    std::string compute_type;
+    std::string target_gpu;
+
+    int num_elements = 0;
+    int num_blocks = 0;
+    int num_groups = 0;
+
+    __half* ref_result = nullptr;
+
+    int packed_weights_len_in_bytes = 0;
+    int absmax_q_len_in_bytes = 0;
+    int absmax2_len_in_bytes = 0;
+
void calculate_params() { + num_elements = num_rows * num_cols; + // group_size = 256; + num_blocks = (num_elements + block_size - 1) / block_size; + num_groups = (num_blocks + group_size - 1) / group_size; + + packed_weights_len_in_bytes = (num_elements + 1) / 2; + absmax_q_len_in_bytes = num_blocks; + absmax2_len_in_bytes = 2 * num_groups; // fp16 bytes + } + + void print() { + std::cout << "[header]" << std::endl; + std::cout << "num_rows: " << num_rows << std::endl; + std::cout << "num_cols: " << num_cols << std::endl; + std::cout << "blocksize: " << block_size << std::endl; + + std::cout << std::endl; + std::cout << "[data]" << std::endl; + std::cout << "packed_weights: " << std::endl; + int print_cnt = 0; + for (int i = 0; i < num_elements; i += 2) { + uint8_t v = packed_weights[i / 2]; + int lower = v & 0xF; + int upper = v >> 4; + std::cout << upper << "\t"; + print_cnt ++; + if (print_cnt == num_cols) { + std::cout << std::endl; + print_cnt = 0; + } + std::cout << lower << "\t"; + print_cnt ++; + if (print_cnt == num_cols) { + std::cout << std::endl; + print_cnt = 0; + } + } + std::cout << "absmax_q:" << std::endl; + for (int i = 0; i < num_blocks; ++i) { + std::cout << (int)absmax_q[i] << " "; + } + std::cout << std::endl; + + std::cout << "absmax2:" << std::endl; + for (int i = 0; i < num_groups; ++i) { + std::cout << __half2float(absmax2[i]) << " "; + } + std::cout << std::endl; + + std::cout << "code2: " << std::endl; + for (int i = 0; i < 256; ++i) { + std::cout << __half2float(code2[i]) << " "; + } + std::cout << std::endl; + + std::cout << "offset: " << offset << std::endl; + + } +}; +// --------- helpers: streaming parse, no full-file scanning ---------- + +static void expect_text(std::istream& is, const char* s) { + for (const char* p = s; *p; ++p) { + char c; + if (!is.get(c)) { + throw std::runtime_error(std::string("Unexpected EOF while expecting: ") + s); + } + if (c != *p) { + std::string msg = "Tag mismatch. 
Expect: "; + msg += s; + msg += " (got different byte)"; + throw std::runtime_error(msg); + } + } +} + +template +static T read_pod(std::istream& is) { + T v{}; + if (!is.read(reinterpret_cast(&v), sizeof(T))) { + throw std::runtime_error("Failed to read POD bytes"); + } + return v; // 假设小端;你写文件也用小端 pack +} + +static void read_bytes(std::istream& is, void* dst, size_t n) { + if (n == 0) return; + if (!is.read(reinterpret_cast(dst), static_cast(n))) { + throw std::runtime_error("Failed to read raw bytes"); + } +} + +static std::string trim_copy(std::string s) { + auto not_space = [](unsigned char ch){ return !std::isspace(ch); }; + s.erase(s.begin(), std::find_if(s.begin(), s.end(), not_space)); + s.erase(std::find_if(s.rbegin(), s.rend(), not_space).base(), s.end()); + return s; +} + +static std::string strip_quotes(std::string s) { + s = trim_copy(std::move(s)); + if (s.size() >= 2) { + char a = s.front(), b = s.back(); + if ((a == '"' && b == '"') || (a == '\'' && b == '\'')) { + return s.substr(1, s.size() - 2); + } + } + return s; +} + +// input_data: w_nf4.bin +// input_conf: 目前不用(保留接口) +static QuantState parse_quant_state(const std::string& input_data, + const std::string& input_conf, + const std::string& ref_result = "") { + + std::ifstream is(input_data, std::ios::binary); + if (!is) { + throw std::runtime_error("Failed to open file: " + input_data); + } + + QuantState st; + + // 你的文件格式(标签文本 + 紧跟二进制)必须严格一致: + // [header]\n + // num_rows: \n + // num_cols: \n + // blocksize: \n + // + // [data]\n + // packed_weights: \n + // absmax_q: \n + // absmax2: \n + // code2: \n + // offset: \n + + expect_text(is, "[header]\n"); + + expect_text(is, "num_rows: "); + int64_t num_rows64 = read_pod(is); + + expect_text(is, "\nnum_cols: "); + int64_t num_cols64 = read_pod(is); + + expect_text(is, "\nblocksize: "); + int32_t blocksize32 = read_pod(is); + + // 注意:QuantState 里用 int,正常矩阵规模不会溢出 + st.num_rows = static_cast(num_rows64); + st.num_cols = static_cast(num_cols64); + 
st.block_size = static_cast(blocksize32); + + st.calculate_params(); + + // header 后你写了 "\n\n[data]\n" + expect_text(is, "\n\n[data]\n"); + + expect_text(is, "packed_weights: "); + st.packed_weights = new uint8_t[st.packed_weights_len_in_bytes]; + read_bytes(is, st.packed_weights, static_cast(st.packed_weights_len_in_bytes)); + + expect_text(is, "\nabsmax_q: "); + st.absmax_q = new uint8_t[st.absmax_q_len_in_bytes]; + read_bytes(is, st.absmax_q, static_cast(st.absmax_q_len_in_bytes)); + + expect_text(is, "\nabsmax2: "); + st.absmax2 = new __half[st.num_groups]; + read_bytes(is, st.absmax2, static_cast(st.absmax2_len_in_bytes)); + + expect_text(is, "\ncode2: "); + read_bytes(is, st.code2, sizeof(__half) * 256); + + expect_text(is, "\noffset: "); + st.offset = read_pod(is); + + + std::ifstream i_conf(input_conf); + if (!i_conf) { + throw std::runtime_error("Failed to open conf file: " + input_conf); + } + + std::string line; + + while (std::getline(i_conf, line)) { + if (!line.empty() && line.back() == '\r') line.pop_back(); // 兼容 CRLF + + // 去掉注释:支持 # 和 // + auto cut_comment = [&](const std::string& marker) { + auto pos = line.find(marker); + if (pos != std::string::npos) line = line.substr(0, pos); + }; + cut_comment("#"); + cut_comment("//"); + + line = trim_copy(line); + if (line.empty()) continue; + + auto eq = line.find('='); + if (eq == std::string::npos) continue; + + std::string key = trim_copy(line.substr(0, eq)); + std::string val = trim_copy(line.substr(eq + 1)); + + if (key == "blocksize") { + int bs = std::stoi(val); + st.block_size = bs; + } else if (key == "compute_type") { + st.compute_type = strip_quotes(val); + } else if (key == "target_gpu") { + st.target_gpu = strip_quotes(val); + } + } + + if (!ref_result.empty()) { + std::ifstream i_ref_res(ref_result); + if (!i_ref_res) { + throw std::runtime_error("Failed to open conf file: " + ref_result); + } + st.ref_result = new __half[st.num_elements]; + if 
(!i_ref_res.read(reinterpret_cast<char*>(st.ref_result), static_cast<std::streamsize>(st.num_elements * 2))) {
+            throw std::runtime_error("Failed to read raw bytes");
+        }
+    }
+
+    return st;
+}
+
diff --git a/03_nf4_dequant/flashzxi/src/main.cu b/03_nf4_dequant/flashzxi/src/main.cu
new file mode 100644
index 0000000..7611646
--- /dev/null
+++ b/03_nf4_dequant/flashzxi/src/main.cu
@@ -0,0 +1,45 @@
+//
+// Created by flashzxi on 2/24/26.
+//
+#include "quant_state.h"
+#include "cuda_runtime.h"
+#include "nf4_dequant.h"
+
+// https://gxtctab8no8.feishu.cn/wiki/UoESwCDZ2iZRcLkdzjvcxTgenOb?from=from_copylink
+
+int main() {
+    int row = 10000;
+    int col = 10000;
+    std::string file_prefix = std::string("/home/core_dump/Learning-CUDA/03_nf4_dequant/flashzxi/test/data/nf4_") + std::to_string(row) + "x" + std::to_string(col) + "_fp16";
+    auto conf = parse_quant_state(file_prefix + ".bin",
+        "/home/core_dump/Learning-CUDA/03_nf4_dequant/flashzxi/test/conf/blocksize64_fp16_T4.ini",
+        file_prefix + "_w_dequant.bin");
+
+    // conf.print();
+
+    // std::cout << "real absmax: ";
+    // for (int i = 0; i < 4; i ++) {
+    //     int idx = conf.absmax_q[i];
+    //     std::cout << __half2float(conf.code2[idx] * conf.absmax2[0]) + conf.offset << " ";
+    // }
+
+    std::cout << std::endl;
+    __half* ans = new __half[conf.num_elements];
+    nf4_dequant_warp8_batch32_one_phase(conf, ans);
+
+    float max_diff = 0.f;
+    for (int i = 0; i < conf.num_rows; i++) {
+        for (int j = 0; j < conf.num_cols; j++) {
+            int idx = i * conf.num_cols + j;
+            float a = __half2float(ans[idx]);
+            float b = __half2float(conf.ref_result[idx]);
+            float diff = fabsf(a - b);
+            diff /= std::max(fabsf(b), 1e-6f); // FIX: ref can be exactly 0 (NF4 code 7) -> guard div-by-zero; |b| keeps the ratio positive
+            max_diff = std::max(max_diff, diff);
+            // std::cout << a << " ";
+        }
+        // std::cout << "\n";
+    }
+    std::cout << "max_diff = " << max_diff << "\n";
+}
+
diff --git a/03_nf4_dequant/flashzxi/src/nf4_dequant_naive.cu b/03_nf4_dequant/flashzxi/src/nf4_dequant_naive.cu
new file mode 100644
index 0000000..7082e56
--- /dev/null
+++ b/03_nf4_dequant/flashzxi/src/nf4_dequant_naive.cu
@@
-0,0 +1,156 @@ +// +// Created by core_dump on 2026/2/25. +// +#include +#include +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include +#include +#include "quant_state.h" +#include "common.cuh" +#include "nf4_dequant.h" + +template +__global__ void dequant_absmax_kernel(const uint8_t* __restrict__ absmax_q, + const FP_T* __restrict__ absmax2, + const FP_T* __restrict__ code2, // 256 + int num_blocks, + int group_size, // blocks per group + float offset, + float* __restrict__ absmax_out) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= num_blocks) return; + + int group = i / group_size; + float s2 = f162float(absmax2[group]); + float c = f162float(code2[absmax_q[i]]); + absmax_out[i] = c * s2 + offset; +} + +// 每个block 128个线程,每个线程负责2个 +template +__global__ void dequant_nf4_kernel(const uint8_t* __restrict__ packed, + const float* __restrict__ absmax, + int num_elements, + int block_size, + OUT_T* __restrict__ out) { + float kNF4[16] = { + -1.0000000f, -0.6961928f, -0.5250731f, -0.3949175f, + -0.2844414f, -0.1847734f, -0.0910500f, 0.0000000f, + 0.0795803f, 0.1609302f, 0.2461123f, 0.3379152f, + 0.4407098f, 0.5626170f, 0.7229568f, 1.0000000f + }; + int t = blockIdx.x * blockDim.x + threadIdx.x; + int elem0 = t * 2; + if (elem0 >= num_elements) return; + + uint8_t byte = packed[t]; + int lo = byte & 0x0F; + int hi = byte >> 4; + + float s0 = absmax[elem0 / block_size]; + float v0 = s0 * kNF4[hi]; + if constexpr (std::is_same_v) { + out[elem0] = __float2half(v0); + } else { + out[elem0] = __float2bfloat16(v0); + } + + int elem1 = elem0 + 1; + if (elem1 < num_elements) { + float s1 = absmax[elem1 / block_size]; + float v1 = s1 * kNF4[lo]; + if constexpr (std::is_same_v) { + out[elem1] = __float2half(v1); + } else { + out[elem1] = __float2bfloat16(v1); + } + } +} + +void nf4_dequant_naive(const QuantState& quant_state, __half* output) { + // 解码scale + uint8_t* scale_q_s; + __half* code2_s; + __half* absmax2_s; + + 
CUDA_CHECK(cudaMalloc(&scale_q_s, quant_state.num_blocks)); + CUDA_CHECK(cudaMalloc(&code2_s, 256 * sizeof(__half))); + CUDA_CHECK(cudaMalloc(&absmax2_s, quant_state.num_groups * sizeof(__half))); + + CUDA_CHECK(cudaMemcpy(scale_q_s, quant_state.absmax_q, quant_state.absmax_q_len_in_bytes, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(code2_s, quant_state.code2, 256 * sizeof(__half), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(absmax2_s, quant_state.absmax2, quant_state.num_groups * sizeof(__half), cudaMemcpyHostToDevice)); + + dim3 dequant_scale_block_dim(128); + dim3 dequant_scale_grid_dim((quant_state.num_blocks + 128 - 1) / 128); + // 解码权重 + float* absmax = nullptr; + size_t absmax_bytes = sizeof(float) * quant_state.num_blocks; + CUDA_CHECK(cudaMalloc(&absmax, absmax_bytes)); + + float* absmax_h = new float[quant_state.num_blocks]; + + if (quant_state.compute_type == "bf16") { + dequant_absmax_kernel<__nv_bfloat16><<>>( + scale_q_s, (__nv_bfloat16*) absmax2_s, + (__nv_bfloat16*) code2_s, quant_state.num_blocks, quant_state.group_size, quant_state.offset, absmax + ); + } else if (quant_state.compute_type == "fp16") { + dequant_absmax_kernel<__half><<>>( + scale_q_s, (__half*) absmax2_s, + (__half*) code2_s, quant_state.num_blocks, quant_state.group_size, quant_state.offset,absmax + ); + } else { + std::cerr << "Type Not Supported, only support bf16 | fp16" << std::endl; + exit(-1); + } + + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + + cudaMemcpy(absmax_h, absmax, quant_state.num_blocks * sizeof(float), cudaMemcpyDeviceToHost); + for (int i = 0; i < quant_state.num_blocks; i++) { + std::cout << absmax_h[i] << " "; + } + std::cout << std::endl; + + CUDA_CHECK(cudaFree(scale_q_s)); + CUDA_CHECK(cudaFree(code2_s)); + CUDA_CHECK(cudaFree(absmax2_s)); + + uint8_t* packed_weights_s; + // output + __half* unpacked_weights_s; + + CUDA_CHECK(cudaMalloc(&packed_weights_s, quant_state.packed_weights_len_in_bytes)); + 
CUDA_CHECK(cudaMalloc(&unpacked_weights_s, quant_state.num_elements * sizeof(__half))); + + CUDA_CHECK(cudaMemcpy(packed_weights_s, quant_state.packed_weights, quant_state.packed_weights_len_in_bytes, cudaMemcpyHostToDevice)) + + dim3 dequant_weights_grid_dim((quant_state.packed_weights_len_in_bytes + dequant_scale_block_dim.x - 1) / dequant_scale_block_dim.x); + + if (quant_state.compute_type == "bf16") { + dequant_nf4_kernel<__nv_bfloat16><<>> ( + packed_weights_s, absmax, quant_state.num_elements, + quant_state.block_size, (__nv_bfloat16*) unpacked_weights_s + ); + } else if (quant_state.compute_type == "fp16") { + dequant_nf4_kernel<__half><<>> ( + packed_weights_s, absmax, quant_state.num_elements, + quant_state.block_size, (__half*) unpacked_weights_s + ); + } else { + std::cerr << "Type Not Supported, only support bf16 | fp16" << std::endl; + exit(-1); + } + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + + CUDA_CHECK(cudaMemcpy(output, unpacked_weights_s, quant_state.num_elements * sizeof(__half), cudaMemcpyDeviceToHost)); + + CUDA_CHECK(cudaFree(packed_weights_s)); + CUDA_CHECK(cudaFree(absmax)); + CUDA_CHECK(cudaFree(unpacked_weights_s)); +} \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/src/nf4_dequant_warpN.cu b/03_nf4_dequant/flashzxi/src/nf4_dequant_warpN.cu new file mode 100644 index 0000000..7bce54a --- /dev/null +++ b/03_nf4_dequant/flashzxi/src/nf4_dequant_warpN.cu @@ -0,0 +1,387 @@ +// +// Created by flashzxi on 2/24/26. 
+// +#include +#include +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include +#include +#include "quant_state.h" +#include "nf4_dequant.h" +#include "common.cuh" + +#define LDST32BITS(value) (reinterpret_cast(&(value))[0]) +#define LDST64BITS(value) (reinterpret_cast(&(value))[0]) +#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) + +// code2 为 256 * f16 +// 每个线程load 2 个,需要128个线程, 故设置一个block 128个线程,每个线程处理N个计算 +// 总计处理128 * N个数据, N 是2的幂 且不小于8 +// 结尾不够需要padding +template +__global__ void dequant_nf4_scale_warp8_batchN_kernel( + uint8_t* scale_q, + HFP_T* code2, + HFP_T* absmax2, + int num_blocks, + int group_size, + float offset, + float* output) { + int lane_id = threadIdx.x; + + // load code2 + __shared__ float shm_code2_float[128]; + + LDST32BITS(shm_code2_float[lane_id]) = LDST32BITS(code2[2 * lane_id]); + HFP_T* shm_code2 = (HFP_T *) shm_code2_float; + __syncthreads(); + + // 一次处理8个数据 + constexpr int loop_times = N / 8; + int g_scale_q_offset_base = blockIdx.x * 128 * N; + + alignas(16) uint8_t fragment[8]; + alignas(16) float cache_res[8]; +#pragma unroll + for (int i = 0; i < loop_times; ++i) { + int scale_offset = g_scale_q_offset_base + i * 128 * 8 + lane_id * 8; + HFP_T scale2 = absmax2[scale_offset / group_size]; + + if (scale_offset + 7 < num_blocks) { + LDST64BITS(fragment[0]) = LDST64BITS(*(scale_q + scale_offset)); +#pragma unroll + for (int j = 0; j < 8; ++j) { + cache_res[j] = f162float( shm_code2[fragment[j]] * scale2 ) + offset; + } + LDST128BITS(output[scale_offset]) = LDST128BITS(cache_res[0]); + LDST128BITS(output[scale_offset + 4]) = LDST128BITS(cache_res[4]); + } else if (scale_offset < num_blocks) { + // 不够一组,退化为每个元素load + int remains = num_blocks - scale_offset; + for (int j = 0; j < remains; ++j) { + fragment[j] = (scale_q + scale_offset)[j]; + cache_res[j] = f162float(shm_code2[fragment[j]] * scale2) + offset; + output[scale_offset + j] = cache_res[j]; + } + } + } +} + +// 一个block 128个线程 +// 每个线程负责N个, 
每个block 负责 128 * N 个数据的解码 +template +__global__ void dequant_nf4_elements_warp8_batchN_kernel(uint8_t* packed_weights, float* absmax, int num_elements, int block_size, HFP_T* output) { + float kNF4[16] = { + -1.0000000f, -0.6961928f, -0.5250731f, -0.3949175f, + -0.2844414f, -0.1847734f, -0.0910500f, 0.0000000f, + 0.0795803f, 0.1609302f, 0.2461123f, 0.3379152f, + 0.4407098f, 0.5626170f, 0.7229568f, 1.0000000f + }; + uint8_t* packed_weights_end = packed_weights + (num_elements + 1) / 2; + + int bidx = blockIdx.x; + int lane_id = threadIdx.x; + + int block_offset = bidx * 128 * N; + + // 每次处理8个,32bits + alignas(16) uint8_t f_packed_weights[4]; + alignas(16) HFP_T cache_res[8]; + constexpr int loop_times = N / 8; +#pragma unroll + for (int i = 0; i < loop_times; ++i) { + int g_packed_weights_offset = block_offset + 8 * 128 * i + 8 * lane_id; + float scale = absmax[g_packed_weights_offset / block_size]; + if (packed_weights + g_packed_weights_offset / 2 + 4 < packed_weights_end) { + LDST32BITS(f_packed_weights[0]) = LDST32BITS(packed_weights[g_packed_weights_offset / 2]); +#pragma unroll + for (int j = 0; j < 4; ++j) { + uint8_t lower = f_packed_weights[j] & 0xF; + uint8_t upper = f_packed_weights[j] >> 4; + if constexpr (std::is_same_v) { + cache_res[2 * j] = __float2half(scale * kNF4[upper]); + cache_res[2 * j + 1] = __float2half(scale * kNF4[lower]); + } else if constexpr (std::is_same_v) { + cache_res[2 * j] = __float2bfloat16(scale * kNF4[upper]); + cache_res[2 * j + 1] = __float2bfloat16(scale * kNF4[lower]); + } + } + LDST128BITS(output[g_packed_weights_offset]) = LDST128BITS(cache_res[0]); + } else if (packed_weights + g_packed_weights_offset / 2 < packed_weights_end) { + int remains = num_elements - g_packed_weights_offset; + for (int j = 0; j < (remains + 1) / 2; ++j) { + f_packed_weights[0] = packed_weights[g_packed_weights_offset / 2 + j]; + uint8_t lower = f_packed_weights[0] & 0xF; + uint8_t upper = f_packed_weights[0] >> 4; + if constexpr (std::is_same_v) 
{ + cache_res[0] = __float2half(scale * kNF4[upper]); + cache_res[1] = __float2half(scale * kNF4[lower]); + } else if constexpr (std::is_same_v) { + cache_res[0] = __float2bfloat16(scale * kNF4[upper]); + cache_res[1] = __float2bfloat16(scale * kNF4[lower]); + } + if (g_packed_weights_offset + 2 * j >= num_elements) { + // 只需要写回第一个 + output[g_packed_weights_offset + 2 * j] = cache_res[0]; + } else { + // 两个打包写回 + LDST32BITS(output[g_packed_weights_offset + 2 * j]) = LDST32BITS(cache_res[0]); + } + } + } + } +} + +// 一个block 128个线程 +// 每个线程负责N个, 每个block 负责 128 * N 个数据的解码 +template +__global__ void dequant_nf4_elements_one_phase_warp8_batchN_kernel( + uint8_t* packed_weights, + uint8_t* absmax_q, + int num_elements, + HFP_T* absmax2, + HFP_T* code2, + int block_size, + int group_size, + float offset, + HFP_T* output) { + float kNF4[16] = { + -1.0000000f, -0.6961928f, -0.5250731f, -0.3949175f, + -0.2844414f, -0.1847734f, -0.0910500f, 0.0000000f, + 0.0795803f, 0.1609302f, 0.2461123f, 0.3379152f, + 0.4407098f, 0.5626170f, 0.7229568f, 1.0000000f + }; + uint8_t* packed_weights_end = packed_weights + (num_elements + 1) / 2; + + int bidx = blockIdx.x; + int lane_id = threadIdx.x; + + // load code2 + __shared__ float shm_code2_float[128]; + + LDST32BITS(shm_code2_float[lane_id]) = LDST32BITS(code2[2 * lane_id]); + HFP_T* shm_code2 = (HFP_T *) shm_code2_float; + __syncthreads(); + + int block_offset = bidx * 128 * N; + + // 每次处理8个,32bits + alignas(16) uint8_t f_packed_weights[4]; + alignas(16) HFP_T cache_res[8]; + constexpr int loop_times = N / 8; +#pragma unroll + for (int i = 0; i < loop_times; ++i) { + int g_packed_weights_offset = block_offset + 8 * 128 * i + 8 * lane_id; + int block_idx = g_packed_weights_offset / block_size; + int group_idx = block_idx / group_size; + float scale = f162float(shm_code2[absmax_q[block_idx]] * absmax2[group_idx]) + offset; + if (packed_weights + g_packed_weights_offset / 2 + 4 < packed_weights_end) { + LDST32BITS(f_packed_weights[0]) = 
LDST32BITS(packed_weights[g_packed_weights_offset / 2]); +#pragma unroll + for (int j = 0; j < 4; ++j) { + uint8_t lower = f_packed_weights[j] & 0xF; + uint8_t upper = f_packed_weights[j] >> 4; + if constexpr (std::is_same_v) { + cache_res[2 * j] = __float2half(scale * kNF4[upper]); + cache_res[2 * j + 1] = __float2half(scale * kNF4[lower]); + } else if constexpr (std::is_same_v) { + cache_res[2 * j] = __float2bfloat16(scale * kNF4[upper]); + cache_res[2 * j + 1] = __float2bfloat16(scale * kNF4[lower]); + } + } + LDST128BITS(output[g_packed_weights_offset]) = LDST128BITS(cache_res[0]); + } else if (packed_weights + g_packed_weights_offset / 2 < packed_weights_end) { + int remains = num_elements - g_packed_weights_offset; + for (int j = 0; j < (remains + 1) / 2; ++j) { + f_packed_weights[0] = packed_weights[g_packed_weights_offset / 2 + j]; + uint8_t lower = f_packed_weights[0] & 0xF; + uint8_t upper = f_packed_weights[0] >> 4; + if constexpr (std::is_same_v) { + cache_res[0] = __float2half(scale * kNF4[upper]); + cache_res[1] = __float2half(scale * kNF4[lower]); + } else if constexpr (std::is_same_v) { + cache_res[0] = __float2bfloat16(scale * kNF4[upper]); + cache_res[1] = __float2bfloat16(scale * kNF4[lower]); + } + if (g_packed_weights_offset + 2 * j >= num_elements) { + // 只需要写回第一个 + output[g_packed_weights_offset + 2 * j] = cache_res[0]; + } else { + // 两个打包写回 + LDST32BITS(output[g_packed_weights_offset + 2 * j]) = LDST32BITS(cache_res[0]); + } + } + } + } +} + +void nf4_dequant_warp8_batch32_two_phase(const QuantState& quant_state, __half* output) { + constexpr int PROCESS_SIZE = 32; + + // 解码scale + uint8_t* scale_q_s; + __half* code2_s; + __half* absmax2_s; + + CUDA_CHECK(cudaMalloc(&scale_q_s, quant_state.num_blocks)); + CUDA_CHECK(cudaMalloc(&code2_s, 256 * sizeof(__half))); + CUDA_CHECK(cudaMalloc(&absmax2_s, quant_state.num_groups * sizeof(__half))); + + CUDA_CHECK(cudaMemcpy(scale_q_s, quant_state.absmax_q, quant_state.absmax_q_len_in_bytes, 
cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(code2_s, quant_state.code2, 256 * sizeof(__half), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(absmax2_s, quant_state.absmax2, quant_state.num_groups * sizeof(__half), cudaMemcpyHostToDevice)); + + Tracer tracer; + tracer.memcpy_accumulate(quant_state.num_blocks) + .memcpy_accumulate(256 * sizeof(__half)) + .memcpy_accumulate(quant_state.num_groups * sizeof(__half)) + .memcpy_accumulate(quant_state.packed_weights_len_in_bytes) + .memcpy_accumulate(quant_state.num_elements * sizeof(__half)); + + dim3 dequant_scale_block_dim(128); + dim3 dequant_scale_grid_dim((quant_state.num_blocks + dequant_scale_block_dim.x * PROCESS_SIZE - 1) / (dequant_scale_block_dim.x * PROCESS_SIZE)); + // 解码权重 + float* absmax = nullptr; + size_t absmax_bytes = sizeof(float) * quant_state.num_blocks; + CUDA_CHECK(cudaMalloc(&absmax, absmax_bytes)); + + float* absmax_h = new float[quant_state.num_blocks]; + + tracer.start(); + if (quant_state.compute_type == "bf16") { + dequant_nf4_scale_warp8_batchN_kernel<__nv_bfloat16, PROCESS_SIZE><<>>( + scale_q_s, (__nv_bfloat16*) code2_s, + (__nv_bfloat16*) absmax2_s, quant_state.num_blocks, quant_state.group_size, quant_state.offset, absmax + ); + } else if (quant_state.compute_type == "fp16") { + dequant_nf4_scale_warp8_batchN_kernel<__half, PROCESS_SIZE><<>>( + scale_q_s, (__half*) code2_s, + (__half*) absmax2_s, quant_state.num_blocks, quant_state.group_size, quant_state.offset,absmax + ); + } else { + std::cerr << "Type Not Supported, only support bf16 | fp16" << std::endl; + exit(-1); + } + tracer.stop(); + + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + + cudaMemcpy(absmax_h, absmax, quant_state.num_blocks * sizeof(float), cudaMemcpyDeviceToHost); + + CUDA_CHECK(cudaFree(scale_q_s)); + CUDA_CHECK(cudaFree(code2_s)); + CUDA_CHECK(cudaFree(absmax2_s)); + + uint8_t* packed_weights_s; + // output + __half* unpacked_weights_s; + + 
CUDA_CHECK(cudaMalloc(&packed_weights_s, quant_state.packed_weights_len_in_bytes)); + CUDA_CHECK(cudaMalloc(&unpacked_weights_s, quant_state.num_elements * sizeof(__half))); + CUDA_CHECK(cudaMemcpy(packed_weights_s, quant_state.packed_weights, quant_state.packed_weights_len_in_bytes, cudaMemcpyHostToDevice)) + + dim3 dequant_weights_grid_dim((quant_state.num_elements + dequant_scale_block_dim.x * PROCESS_SIZE - 1) / (dequant_scale_block_dim.x * PROCESS_SIZE)); + + tracer.start(); + if (quant_state.compute_type == "bf16") { + dequant_nf4_elements_warp8_batchN_kernel<__nv_bfloat16, PROCESS_SIZE><<>> ( + packed_weights_s, absmax, quant_state.num_elements, + quant_state.block_size, (__nv_bfloat16*) unpacked_weights_s + ); + } else if (quant_state.compute_type == "fp16") { + dequant_nf4_elements_warp8_batchN_kernel<__half, PROCESS_SIZE><<>> ( + packed_weights_s, absmax, quant_state.num_elements, + quant_state.block_size, (__half*) unpacked_weights_s + ); + } else { + std::cerr << "Type Not Supported, only support bf16 | fp16" << std::endl; + exit(-1); + } + tracer.stop(); + tracer.print(); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + + CUDA_CHECK(cudaMemcpy(output, unpacked_weights_s, quant_state.num_elements * sizeof(__half), cudaMemcpyDeviceToHost)); + + CUDA_CHECK(cudaFree(packed_weights_s)); + CUDA_CHECK(cudaFree(absmax)); + CUDA_CHECK(cudaFree(unpacked_weights_s)); +} + +void nf4_dequant_warp8_batch32_one_phase(const QuantState& quant_state, __half* output) { + constexpr int PROCESS_SIZE = 32; + + uint8_t* absmax_q_s; + __half* code2_s; + __half* absmax2_s; + uint8_t* packed_weights_s; + // output + __half* unpacked_weights_s; + + CUDA_CHECK(cudaMalloc(&absmax_q_s, quant_state.num_blocks)); + CUDA_CHECK(cudaMalloc(&code2_s, 256 * sizeof(__half))); + CUDA_CHECK(cudaMalloc(&absmax2_s, quant_state.num_groups * sizeof(__half))); + CUDA_CHECK(cudaMalloc(&packed_weights_s, quant_state.packed_weights_len_in_bytes)); + 
CUDA_CHECK(cudaMalloc(&unpacked_weights_s, quant_state.num_elements * sizeof(__half))); + Tracer tracer; + tracer.memcpy_accumulate(quant_state.num_blocks) + .memcpy_accumulate(256 * sizeof(__half)) + .memcpy_accumulate(quant_state.num_groups * sizeof(__half)) + .memcpy_accumulate(quant_state.packed_weights_len_in_bytes) + .memcpy_accumulate(quant_state.num_elements * sizeof(__half)); + + CUDA_CHECK(cudaMemcpy(absmax_q_s, quant_state.absmax_q, quant_state.absmax_q_len_in_bytes, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(code2_s, quant_state.code2, 256 * sizeof(__half), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(absmax2_s, quant_state.absmax2, quant_state.num_groups * sizeof(__half), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(packed_weights_s, quant_state.packed_weights, quant_state.packed_weights_len_in_bytes, cudaMemcpyHostToDevice)) + + dim3 dequant_scale_block_dim(128); + dim3 dequant_weights_grid_dim((quant_state.num_elements + dequant_scale_block_dim.x * PROCESS_SIZE - 1) / (dequant_scale_block_dim.x * PROCESS_SIZE)); + + tracer.start(); + if (quant_state.compute_type == "bf16") { + dequant_nf4_elements_one_phase_warp8_batchN_kernel<__nv_bfloat16, PROCESS_SIZE><<>> ( + packed_weights_s, + absmax_q_s, + quant_state.num_elements, + (__nv_bfloat16*) absmax2_s, + (__nv_bfloat16*) code2_s, + quant_state.block_size, + quant_state.group_size, + quant_state.offset, + (__nv_bfloat16*) unpacked_weights_s + ); + } else if (quant_state.compute_type == "fp16") { + dequant_nf4_elements_one_phase_warp8_batchN_kernel<__half, PROCESS_SIZE><<>> ( + packed_weights_s, + absmax_q_s, + quant_state.num_elements, + (__half*) absmax2_s, + (__half*) code2_s, + quant_state.block_size, + quant_state.group_size, + quant_state.offset, + (__half*) unpacked_weights_s + ); + } else { + std::cerr << "Type Not Supported, only support bf16 | fp16" << std::endl; + exit(-1); + } + tracer.stop(); + tracer.print(); + + CUDA_CHECK(cudaGetLastError()); + 
CUDA_CHECK(cudaDeviceSynchronize()); + + CUDA_CHECK(cudaMemcpy(output, unpacked_weights_s, quant_state.num_elements * sizeof(__half), cudaMemcpyDeviceToHost)); + + CUDA_CHECK(cudaFree(absmax_q_s)); + CUDA_CHECK(cudaFree(code2_s)); + CUDA_CHECK(cudaFree(absmax2_s)); + CUDA_CHECK(cudaFree(packed_weights_s)); + CUDA_CHECK(cudaFree(unpacked_weights_s)); + +} \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/test/conf/blocksize128_bf16_T4.ini b/03_nf4_dequant/flashzxi/test/conf/blocksize128_bf16_T4.ini new file mode 100644 index 0000000..f8359b5 --- /dev/null +++ b/03_nf4_dequant/flashzxi/test/conf/blocksize128_bf16_T4.ini @@ -0,0 +1,3 @@ +blocksize = 128 +compute_type = "bf16" +target_gpu = "T4" \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/test/conf/blocksize128_fp16_T4.ini b/03_nf4_dequant/flashzxi/test/conf/blocksize128_fp16_T4.ini new file mode 100644 index 0000000..8fc1503 --- /dev/null +++ b/03_nf4_dequant/flashzxi/test/conf/blocksize128_fp16_T4.ini @@ -0,0 +1,3 @@ +blocksize = 128 +compute_type = "fp16" +target_gpu = "T4" \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/test/conf/blocksize64_bf16_T4.ini b/03_nf4_dequant/flashzxi/test/conf/blocksize64_bf16_T4.ini new file mode 100644 index 0000000..a7c8fa0 --- /dev/null +++ b/03_nf4_dequant/flashzxi/test/conf/blocksize64_bf16_T4.ini @@ -0,0 +1,3 @@ +blocksize = 64 +compute_type = "bf16" +target_gpu = "T4" \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/test/conf/blocksize64_fp16_T4.ini b/03_nf4_dequant/flashzxi/test/conf/blocksize64_fp16_T4.ini new file mode 100644 index 0000000..b80eab6 --- /dev/null +++ b/03_nf4_dequant/flashzxi/test/conf/blocksize64_fp16_T4.ini @@ -0,0 +1,3 @@ +blocksize = 64 +compute_type = "fp16" +target_gpu = "T4" \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/test/data/baseline.py b/03_nf4_dequant/flashzxi/test/data/baseline.py new file mode 100644 index 0000000..4ca6f79 --- /dev/null +++ 
b/03_nf4_dequant/flashzxi/test/data/baseline.py @@ -0,0 +1,146 @@ +import struct +import torch +import bitsandbytes.functional as F +import time + +def _dequant_bnb(qweight: torch.Tensor, qs): + """ + 兼容不同 bnb 版本的反量化入口: + 优先用 dequantize_4bit;没有的话再退到 dequantize_blockwise。 + """ + if hasattr(F, "dequantize_4bit"): + # 新版常见:直接传 quant_state + return F.dequantize_4bit(qweight, quant_state=qs) + if hasattr(F, "dequantize_blockwise"): + # 老版可能需要 absmax/code 等;但如果传 quant_state 通常也能工作 + return F.dequantize_blockwise(qweight, quant_state=qs) + raise RuntimeError("当前 bitsandbytes.functional 里找不到 dequantize_4bit / dequantize_blockwise") + +def save_nf4_tagged_binary(path: str, W: torch.Tensor, blocksize: int = 64): + """ + 写 w_nf4.bin: + [header]\n + num_rows:\n + num_cols:\n + blocksize:\n\n + [data]\n + packed_weights:\n + absmax_q:\n + absmax2:\n + code2:\n + offset:\n + """ + assert W.ndim == 2 and W.is_cuda + + num_rows, num_cols = map(int, W.shape) + num_elements = num_rows * num_cols + num_blocks = (num_elements + blocksize - 1) // blocksize + + qweight, qs = F.quantize_4bit( + W, + blocksize=blocksize, + quant_type="nf4", + compress_statistics=True, + quant_storage=torch.uint8, + ) + if not getattr(qs, "nested", False) or qs.state2 is None: + raise RuntimeError("需要 compress_statistics=True 才会有 absmax_q/absmax2/code2/offset") + + packed = qweight.detach().contiguous().view(torch.uint8).cpu() + packed_len = (num_elements + 1) // 2 + if packed.numel() != packed_len: + raise RuntimeError(f"packed_weights len mismatch: got={packed.numel()} expected={packed_len}") + + absmax_q = qs.absmax.detach().contiguous().view(torch.uint8).cpu() + if absmax_q.numel() != num_blocks: + raise RuntimeError(f"absmax_q len mismatch: got={absmax_q.numel()} expected={num_blocks}") + + absmax2 = qs.state2.absmax.detach().contiguous().cpu().to(torch.float16) + code2 = qs.state2.code.detach().contiguous().cpu().to(torch.float16) + if code2.numel() != 256: + raise RuntimeError(f"code2 len mismatch: 
got={code2.numel()} expected=256") + + offset = float(qs.offset) if qs.offset is not None else 0.0 + + with open(path, "wb") as f: + f.write(b"[header]\n") + f.write(b"num_rows: ") + f.write(struct.pack("\n + num_cols:\n + dtype:<1 byte tag>\n + data:\n + + dtype tag: 1 = fp16, 2 = fp32 + """ + num_rows, num_cols = shape + deq2d = deq.reshape(num_rows, num_cols).detach() + + if out_dtype == torch.float16: + tag = 1 + host = deq2d.to(torch.float16).contiguous().cpu() + elif out_dtype == torch.float32: + tag = 2 + host = deq2d.to(torch.float32).contiguous().cpu() + else: + raise ValueError("out_dtype 只支持 torch.float16 或 torch.float32") + + with open(path, "wb") as f: + f.write(host.numpy().tobytes(order="C")) + + # with open(path + ".txt", "w", encoding="utf-8") as f: + # f.write("[dequant]\n") + # f.write(f"num_rows: {int(num_rows)}\n") + # f.write(f"num_cols: {int(num_cols)}\n") + # f.write(f"dtype: {int(tag)}\n") + # f.write("data:\n") + # + # # 逐行写,空格分隔 + # # 可以按需要改格式,比如 "{:.6f}" + # for i in range(num_rows): + # row = host[i].tolist() + # f.write(" ".join(f"{v:.6f}" for v in row)) + # f.write("\n") + # # print(" ".join(f"{v:.6f}" for v in row)) + +if __name__ == "__main__": + torch.manual_seed(0) + torch.manual_seed(1234) # CPU RNG + torch.cuda.manual_seed_all(1234) # 所有 GPU RNG(单卡也可用) + row = 10000 + col = 10000 + W = torch.randn(row, col, device="cuda", dtype=torch.float16) + file_prefix = f"nf4_{row}x{col}_fp16" + qweight, qs, shape = save_nf4_tagged_binary(file_prefix + ".bin", W, blocksize=64) + + start = time.perf_counter() + deq = _dequant_bnb(qweight, qs) # bnb 反量化 + end = time.perf_counter() + elapsed_ms = (end - start) * 1000 + print(f"dequantize_4bit执行时间: {elapsed_ms:.3f} ms") + save_dequant_result(file_prefix + "_w_dequant.bin", deq, shape, out_dtype=torch.float16) diff --git a/03_nf4_dequant/src/main.cpp b/03_nf4_dequant/src/main.cpp deleted file mode 100644 index 1318c41..0000000 --- a/03_nf4_dequant/src/main.cpp +++ /dev/null @@ -1,3 +0,0 @@ -// 
-// Created by flashzxi on 2/24/26. -// diff --git a/03_nf4_dequant/src/nf4_dequant.cu b/03_nf4_dequant/src/nf4_dequant.cu deleted file mode 100644 index 3d97e17..0000000 --- a/03_nf4_dequant/src/nf4_dequant.cu +++ /dev/null @@ -1,152 +0,0 @@ -// -// Created by flashzxi on 2/24/26. -// -#include -#include -#include "cutlass/core_io.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/util/host_tensor.h" -#include "cutlass/numeric_types.h" -#include -#include -#define ENABLE_DOUBLE_BUFFER - -#define INT4(value) (reinterpret_cast(&(value))[0]) -#define FLOAT4(value) (reinterpret_cast(&(value))[0]) -#define HALF2(value) (reinterpret_cast(&(value))[0]) -#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) -#define LDST64BITS(value) (reinterpret_cast(&(value))[0]) -#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) - -struct QuantState { - // header - int num_rows; - int num_cols; - int block_size; - int group_size; - - // data - uint8_t* packed_weights; // 每字节存两个 4-bit 索引 - uint8_t* absmax_q; - __half* absmax2; - __half code2[256]; // 二级码表 - float offset; - - // runtime param - std::string compute_type; - std::string target_gpu; - - int num_elements; - int num_blocks; - int num_groups; - - int packed_weights_len_in_bytes; - int absmax_q_len_in_bytes; - int absmax2_len_in_bytes; - - // 输出位置 - uint8_t *output; - - void calculate_params() { - num_elements = num_rows * num_cols; - group_size = block_size; - num_blocks = (num_elements + block_size - 1) / block_size; - num_groups = (num_blocks + group_size - 1) / group_size; - - packed_weights_len_in_bytes = (num_elements + 1) / 2; - absmax_q_len_in_bytes = num_blocks; - absmax2_len_in_bytes = 2 * num_groups; - } -}; - -// code2 为 256 * f16 -// 每个线程load 2 个,需要128个线程, 故设置一个block 128个线程,每个线程处理N个计算 -// 总计处理128 * N个数据, N 是2的幂 且不小于8 -// 结尾不够需要padding -template -__global__ void dequant_nf4_scale_f16xN_kernel(uint8_t* scale_q, FP_T* code2, FP_T* absmax2, int num_blocks, int group_size, FP_T* output) 
{ - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int lane_id = threadIdx.x; - - // load code2 - __shared__ float shm_code2[128]; - shm_code2[lane_id] = code2[lane_id]; - - // 一次处理8个数据 - constexpr int loop_times = N / 8; - int g_scale_q_offset_base = blockIdx.x * 128 * N; - - // 使用double buffer需要使用shared_memory - // 不使用double buffer 可以直接再寄存器上暂存数据 -#ifdef ENABLE_DOUBLE_BUFFER - auto block = cooperative_groups::this_thread_block(); - - constexpr int STAGES = 2; // 双缓冲 - __shared__ cuda::pipeline_shared_state ps; - auto pipe = cuda::make_pipeline(block, &ps); - - // 一次读取8个数据, 双buffer - __shared__ uint8_t fragment[2][128][8]; - // 取第0块 - int scale_offset = g_scale_q_offset_base + lane_id * 8; - if (scale_offset + 8 <= num_blocks) { - pipe.producer_acquire(); - cuda::memcpy_async(block, fragment[0][lane_id], scale_q + scale_offset, cuda::aligned_size_t<16>(16), pipe); - pipe.producer_commit(); - } else if (scale_offset < num_blocks) { - // 尾块处理:退化为标量copy - for (int i = 0; i < 8 && scale_offset + i < num_blocks; ++i) { - fragment[0][lane_id][i] = scale_q[scale_offset + i]; - } - } - -#pragma unroll - for (int i = 0; i < loop_times; ++i) { - - } -#else - uint8_t fragment[8]; - FP_T cache_res[8]; -#pragma unroll - for (int i = 0; i < loop_times; ++i) { - int scale_offset = g_scale_q_offset_base + i * 128 * 8 + lane_id * 8; - FP_T scale2 = absmax2[i * 128 * 8 + lane_id * 8 / group_size]; - if (scale_offset + 7 < num_blocks) { - LDST64BITS(fragment[0]) = LDST64BITS(*(scale_q + scale_offset)); -#pragma unroll - for (int j = 0; j < 8; ++j) { - cache_res[j] = ((FP_T*) shm_code2)[fragment[i]] * scale2; - } - LDST128BITS(output[scale_offset]) = LDST128BITS(cache_res[0]); - } else if (scale_offset < num_blocks) { - // 不够一组,退化为每个元素load - int remains = num_blocks - scale_offset; - for (int j = 0; j < remains; ++j) { - fragment[j] = (scale_q + scale_offset)[j]; - cache_res[j] = ((FP_T*) shm_code2)[fragment[i]] * scale2; - output[scale_offset + j] = cache_res[j]; - } - } - } 
-#endif -} - -void nf4_dequant(const QuantState& quant_state) { - // 解码scale - cutlass::HostTensor scale_q(cutlass::layout::PitchLinearCoord(quant_state.num_blocks, 1)); - cutlass::HostTensor code2(cutlass::layout::PitchLinearCoord(256, 1)); - cutlass::HostTensor absmax2(cutlass::layout::PitchLinearCoord(quant_state.num_groups, 1)); - - memcpy(scale_q.host_data(), quant_state.absmax_q, quant_state.absmax_q_len_in_bytes); - memcpy(code2.host_data(), quant_state.code2, 256 * 2); - memcpy(absmax2.host_data(), quant_state.absmax2, quant_state.absmax2_len_in_bytes); - - scale_q.sync_device(); - code2.sync_device(); - absmax2.sync_device(); - - constexpr int dequant_scale_per_thread_calc = 8; - dim3 dequant_scale_block_dim(128); - dim3 dequant_scale_grid_dim((quant_state.block_size + 128 * dequant_scale_per_thread_calc - 1) / 128 * dequant_scale_per_thread_calc); - // 解码权重 -} \ No newline at end of file From cb40f84e30bc19ce4bf0e600ea26851ed20f3c95 Mon Sep 17 00:00:00 2001 From: flashzxi Date: Thu, 26 Feb 2026 23:26:23 +0800 Subject: [PATCH 3/5] =?UTF-8?q?nvc=20=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 03_nf4_dequant/flashzxi/CMakeLists.txt | 8 ++++++- 03_nf4_dequant/flashzxi/src/main.cu | 4 ++-- .../flashzxi/src/nf4_dequant_naive.cu | 2 -- ..._dequant_warpN.cu => nf4_dequant_warp8.cu} | 21 +++++++++++-------- 4 files changed, 21 insertions(+), 14 deletions(-) rename 03_nf4_dequant/flashzxi/src/{nf4_dequant_warpN.cu => nf4_dequant_warp8.cu} (96%) diff --git a/03_nf4_dequant/flashzxi/CMakeLists.txt b/03_nf4_dequant/flashzxi/CMakeLists.txt index bf93ab2..3a3f635 100644 --- a/03_nf4_dequant/flashzxi/CMakeLists.txt +++ b/03_nf4_dequant/flashzxi/CMakeLists.txt @@ -9,10 +9,16 @@ set(CMAKE_CUDA_STANDARD_REQUIRED ON) set(CMAKE_CUDA_ARCHITECTURES native) +if(NOT CMAKE_CONFIGURATION_TYPES AND NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING "Build type" FORCE) + set_property(CACHE 
CMAKE_BUILD_TYPE PROPERTY STRINGS + "Debug" "Release" "RelWithDebInfo" "MinSizeRel") +endif() + add_executable(nf4_dequant src/main.cu src/nf4_dequant_naive.cu - src/nf4_dequant_warpN.cu + src/nf4_dequant_warp8.cu ) target_include_directories(nf4_dequant PRIVATE diff --git a/03_nf4_dequant/flashzxi/src/main.cu b/03_nf4_dequant/flashzxi/src/main.cu index 7611646..161261c 100644 --- a/03_nf4_dequant/flashzxi/src/main.cu +++ b/03_nf4_dequant/flashzxi/src/main.cu @@ -10,9 +10,9 @@ int main() { int row = 10000; int col = 10000; - std::string file_prefix = std::string("/home/core_dump/Learning-CUDA/03_nf4_dequant/flashzxi/test/data/nf4_") + std::to_string(row) + "x" + std::to_string(col) + "_fp16"; + std::string file_prefix = std::string("/home/flashzxi/CLionProjects/Learning-CUDA/03_nf4_dequant/flashzxi/test/data/nf4_") + std::to_string(row) + "x" + std::to_string(col) + "_fp16"; auto conf = parse_quant_state(file_prefix + ".bin", - "/home/core_dump/Learning-CUDA/03_nf4_dequant/flashzxi/test/conf/blocksize64_fp16_T4.ini", + "/home/flashzxi/CLionProjects/Learning-CUDA/03_nf4_dequant/flashzxi/test/conf/blocksize64_fp16_T4.ini", file_prefix + "_w_dequant.bin"); // conf.print(); diff --git a/03_nf4_dequant/flashzxi/src/nf4_dequant_naive.cu b/03_nf4_dequant/flashzxi/src/nf4_dequant_naive.cu index 7082e56..5d5891f 100644 --- a/03_nf4_dequant/flashzxi/src/nf4_dequant_naive.cu +++ b/03_nf4_dequant/flashzxi/src/nf4_dequant_naive.cu @@ -3,8 +3,6 @@ // #include #include -#include "cutlass/core_io.h" -#include "cutlass/util/host_tensor.h" #include #include #include "quant_state.h" diff --git a/03_nf4_dequant/flashzxi/src/nf4_dequant_warpN.cu b/03_nf4_dequant/flashzxi/src/nf4_dequant_warp8.cu similarity index 96% rename from 03_nf4_dequant/flashzxi/src/nf4_dequant_warpN.cu rename to 03_nf4_dequant/flashzxi/src/nf4_dequant_warp8.cu index 7bce54a..fc7332d 100644 --- a/03_nf4_dequant/flashzxi/src/nf4_dequant_warpN.cu +++ b/03_nf4_dequant/flashzxi/src/nf4_dequant_warp8.cu @@ -3,8 +3,6 
@@ // #include #include -#include "cutlass/core_io.h" -#include "cutlass/util/host_tensor.h" #include #include #include "quant_state.h" @@ -146,7 +144,7 @@ __global__ void dequant_nf4_elements_one_phase_warp8_batchN_kernel( int group_size, float offset, HFP_T* output) { - float kNF4[16] = { + constexpr float kNF4[16] = { -1.0000000f, -0.6961928f, -0.5250731f, -0.3949175f, -0.2844414f, -0.1847734f, -0.0910500f, 0.0000000f, 0.0795803f, 0.1609302f, 0.2461123f, 0.3379152f, @@ -157,12 +155,12 @@ __global__ void dequant_nf4_elements_one_phase_warp8_batchN_kernel( int bidx = blockIdx.x; int lane_id = threadIdx.x; - // load code2 - __shared__ float shm_code2_float[128]; + // load code2 不用shared memory更快 +// __shared__ float shm_code2_float[128]; - LDST32BITS(shm_code2_float[lane_id]) = LDST32BITS(code2[2 * lane_id]); - HFP_T* shm_code2 = (HFP_T *) shm_code2_float; - __syncthreads(); +// LDST32BITS(shm_code2_float[lane_id]) = LDST32BITS(code2[2 * lane_id]); +// HFP_T* shm_code2 = (HFP_T *) shm_code2_float; +// __syncthreads(); int block_offset = bidx * 128 * N; @@ -175,7 +173,12 @@ __global__ void dequant_nf4_elements_one_phase_warp8_batchN_kernel( int g_packed_weights_offset = block_offset + 8 * 128 * i + 8 * lane_id; int block_idx = g_packed_weights_offset / block_size; int group_idx = block_idx / group_size; - float scale = f162float(shm_code2[absmax_q[block_idx]] * absmax2[group_idx]) + offset; + + HFP_T h2[2]; + uint8_t q = absmax_q[block_idx]; + LDST32BITS(h2[0]) = LDST32BITS(code2[(q >> 1) << 1]); // 读 32-bit + HFP_T h = (q & 1) ? 
h2[1] : h2[0]; + float scale = f162float(h * absmax2[group_idx]) + offset; if (packed_weights + g_packed_weights_offset / 2 + 4 < packed_weights_end) { LDST32BITS(f_packed_weights[0]) = LDST32BITS(packed_weights[g_packed_weights_offset / 2]); #pragma unroll From a5e3482723b6ec7ebce296c4843024463e74bbea Mon Sep 17 00:00:00 2001 From: flashzxi Date: Sat, 14 Mar 2026 20:37:35 +0800 Subject: [PATCH 4/5] stash --- 04_hadamard_tc/flashzxi/CMakeLists.txt | 36 ++ 04_hadamard_tc/flashzxi/include/h16_bf16.inc | 16 + 04_hadamard_tc/flashzxi/include/h16_fp16.inc | 16 + 04_hadamard_tc/flashzxi/include/hadacore.hpp | 34 ++ 04_hadamard_tc/flashzxi/src/hadacore.cu | 451 +++++++++++++++++++ 04_hadamard_tc/flashzxi/src/main.cu | 12 + 6 files changed, 565 insertions(+) create mode 100644 04_hadamard_tc/flashzxi/CMakeLists.txt create mode 100644 04_hadamard_tc/flashzxi/include/h16_bf16.inc create mode 100644 04_hadamard_tc/flashzxi/include/h16_fp16.inc create mode 100644 04_hadamard_tc/flashzxi/include/hadacore.hpp create mode 100644 04_hadamard_tc/flashzxi/src/hadacore.cu create mode 100644 04_hadamard_tc/flashzxi/src/main.cu diff --git a/04_hadamard_tc/flashzxi/CMakeLists.txt b/04_hadamard_tc/flashzxi/CMakeLists.txt new file mode 100644 index 0000000..c93a60f --- /dev/null +++ b/04_hadamard_tc/flashzxi/CMakeLists.txt @@ -0,0 +1,36 @@ +cmake_minimum_required(VERSION 3.26) + +project(hadacore LANGUAGES CXX CUDA) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) + +set(CMAKE_CUDA_ARCHITECTURES 80) + +find_package(CUDAToolkit REQUIRED) + +include(FetchContent) + +FetchContent_Declare( + cutlass + GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git + GIT_TAG v3.4.0 +) + +FetchContent_MakeAvailable(cutlass) + +add_executable(hadacore + src/main.cu + src/hadacore.cu +) + +target_include_directories(hadacore PRIVATE + ${CMAKE_SOURCE_DIR}/include + ${cutlass_SOURCE_DIR}/include + ${cutlass_SOURCE_DIR}/tools/util/include + ${CUDAToolkit_INCLUDE_DIRS} +) + 
+set_target_properties(hadacore PROPERTIES + CUDA_SEPARABLE_COMPILATION OFF +) \ No newline at end of file diff --git a/04_hadamard_tc/flashzxi/include/h16_bf16.inc b/04_hadamard_tc/flashzxi/include/h16_bf16.inc new file mode 100644 index 0000000..379658c --- /dev/null +++ b/04_hadamard_tc/flashzxi/include/h16_bf16.inc @@ -0,0 +1,16 @@ +0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, +0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, +0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, +0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, +0x3f80, 0x3f80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, +0x3f80, 0xbf80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, +0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, +0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, +0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, +0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, +0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, +0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, +0x3f80, 0x3f80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 
0x3f80, 0x3f80, +0x3f80, 0xbf80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, +0x3f80, 0x3f80, 0xbf80, 0xbf80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, +0x3f80, 0xbf80, 0xbf80, 0x3f80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0xbf80, 0x3f80, 0x3f80, 0xbf80, 0x3f80, 0xbf80, 0xbf80, 0x3f80 \ No newline at end of file diff --git a/04_hadamard_tc/flashzxi/include/h16_fp16.inc b/04_hadamard_tc/flashzxi/include/h16_fp16.inc new file mode 100644 index 0000000..5fea632 --- /dev/null +++ b/04_hadamard_tc/flashzxi/include/h16_fp16.inc @@ -0,0 +1,16 @@ +0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, +0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, +0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, +0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, +0x3c00, 0x3c00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, +0x3c00, 0xbc00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, +0x3c00, 0x3c00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, +0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, +0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, +0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, +0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 
0xbc00, 0xbc00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, +0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, +0x3c00, 0x3c00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, +0x3c00, 0xbc00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, +0x3c00, 0x3c00, 0xbc00, 0xbc00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, +0x3c00, 0xbc00, 0xbc00, 0x3c00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0xbc00, 0x3c00, 0x3c00, 0xbc00, 0x3c00, 0xbc00, 0xbc00, 0x3c00 \ No newline at end of file diff --git a/04_hadamard_tc/flashzxi/include/hadacore.hpp b/04_hadamard_tc/flashzxi/include/hadacore.hpp new file mode 100644 index 0000000..fcd4897 --- /dev/null +++ b/04_hadamard_tc/flashzxi/include/hadacore.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, \ + cudaGetErrorString(err)); \ + exit(1); \ + } \ + } while(0) + +namespace hadacore +{ + +void test_small(); +void test_large(); + +} // namespace hadacore \ No newline at end of file diff --git a/04_hadamard_tc/flashzxi/src/hadacore.cu b/04_hadamard_tc/flashzxi/src/hadacore.cu new file mode 100644 index 0000000..c3851fe --- /dev/null +++ b/04_hadamard_tc/flashzxi/src/hadacore.cu @@ -0,0 +1,451 @@ +// +// Created by core_dump on 3/14/26. 
+// +#include "hadacore.hpp" + +namespace hadacore +{ +using namespace cute; +const int M = 16; + +__device__ __constant__ uint16_t H16_fp16_bin[M * M] = { +#include "../include/h16_fp16.inc" +}; +__device__ __constant__ uint16_t H16_bf16_bin[M * M] = { +#include "../include/h16_bf16.inc" +}; + +__device__ __constant__ half_t* H16_fp16 = (half_t*) H16_fp16_bin; +__device__ __constant__ bfloat16_t* H16_bf16 = (bfloat16_t*) H16_bf16_bin; + +// 处理 64 < R_WIDTH < 256 +template +__global__ void hada_core_less_256(T* A, T* O_scope) { + + constexpr int ROWS = R_WIDTH / M; + + __shared__ __align__(32) int16_t smemA_bin[M * M]; + __shared__ __align__(32) int16_t smemhada_bin[M * M]; + __shared__ __align__(32) int16_t smemC_bin[M * M]; // 暂存 A * H + + T* smemA = (T*) smemA_bin; + T* smemC = (T*) smemC_bin; + T* smemhada = (T*) smemhada_bin; + + T* hada_ptr = nullptr; + + if constexpr (std::is_same_v) { + hada_ptr = H16_fp16; + } else { + hada_ptr = H16_bf16; + } + + auto gA = make_tensor( + make_gmem_ptr(A), + make_shape(Int{}, Int{}), + make_stride(Int{}, Int<1>{}) + ); + + auto sA = make_tensor( + make_smem_ptr(smemA), + make_shape(Int{}, Int{}), + make_stride(Int{}, Int<1>{}) + ); + + auto thr_sA = local_partition( + sA, + Layout>{}, + threadIdx.x + ); + + clear(thr_sA); // 清空 + __syncthreads(); + + auto gH = make_tensor( + make_gmem_ptr(hada_ptr), + make_shape(Int{}, Int{}), + make_stride(Int{}, Int<1>{}) + ); + + auto sH = make_tensor( + make_smem_ptr(smemhada), + make_shape(Int{}, Int{}), + make_stride(Int{}, Int<1>{}) + ); + + auto sC = make_tensor( + make_smem_ptr(smemC), + make_shape(Int{}, Int{}), + make_stride(Int{}, Int<1>{}) + ); + + using CopyAtom = Copy_Atom, T>; + + auto copyA = make_tiled_copy( + CopyAtom{}, + Layout, _2>>{}, + Layout>{} + ); + + auto copyH = make_tiled_copy( + CopyAtom{}, + Layout>{}, + Layout>{} + ); + + auto sA_sub = local_tile( + sA, + make_shape(Int{}, Int{}), + make_coord(0, 0) + ); + + auto thr_copy_a = copyA.get_slice(threadIdx.x); + 
auto tAgA = thr_copy_a.partition_S(gA); + auto tAsA = thr_copy_a.partition_D(sA_sub); + + auto thr_copy_h = copyH.get_slice(threadIdx.x); + auto tHgH = thr_copy_h.partition_S(gH); + auto tHsH = thr_copy_h.partition_D(sH); + + copy(tAgA, tAsA); + copy(tHgH, tHsH); + + if (threadIdx.x == 0) { + print_tensor(sA); + print_tensor(sH); + } + + __syncthreads(); + + using MMA_Atom_Arch = MMA_Atom; + + // 一个 16x8x16 atom,沿 M 方向铺 2 份 => 16x16x16 + auto mma = make_tiled_mma( + MMA_Atom_Arch{}, + Layout>{}, + Layout>{} + ); + + auto thr_mma = mma.get_slice(threadIdx.x); + + // ------------------------------- + // 1) 右乘 H: A x H -> C + // ------------------------------- + + auto tCsA = thr_mma.partition_A(sA); // A + auto tCsB = thr_mma.partition_B(sH); // H + auto tCsC = thr_mma.partition_C(sC); // C + + auto tCrC = thr_mma.make_fragment_C(tCsC); + clear(tCrC); + + gemm(mma, tCsA, tCsB, tCrC); + + copy(tCrC, tCsC); + + __syncthreads(); + + if (threadIdx.x == 0) { + print_tensor(sC); + } + + auto sCt = make_tensor( + make_smem_ptr(smemC), + make_shape(Int{}, Int{}), + make_stride(Int<1>{}, Int{}) + ); // sC 的转置 view + + auto t2sA = thr_mma.partition_A(sH); + auto t2sB = thr_mma.partition_B(sCt); + auto t2sC = thr_mma.partition_C(sA); + + auto t2rC = thr_mma.make_fragment_C(t2sC); + clear(t2rC); + + gemm(mma, t2sA, t2sB, t2rC); + + copy(t2rC, t2sC); + __syncthreads(); + + if (threadIdx.x == 0) { + print_tensor(sA); // row-major 查看结果 + } + copy(tAsA, tAgA); + + __syncthreads(); +} + +// 每个block负责计算一行 +// 一次计算256,一个warp计算 CHUNKS 个256 +// 一个block R_WIDTH / 256 / CHUNKS 个warp +template +__global__ void hadacore_large(const T* A) { + + constexpr int WARPS = R_WIDTH / 256 / CHUNKS; + extern __shared__ __align__(16) char smemA[]; + __shared__ __align__(32) int16_t smemhada_bin[M * M]; + + T* smemA_total = (T*) smemA; + T* smemhada = (T*)smemhada_bin; + T* hada_ptr = nullptr; + if constexpr (std::is_same_v) { + hada_ptr = H16_fp16; + } else { + hada_ptr = H16_bf16; + } + + auto gA_total 
= make_tensor( + make_gmem_ptr(A), + make_shape(Int{}, Int{}, Int{}), + make_stride(Int{}, Int{}, Int<1>{}) + ); + + auto sA_total = make_tensor( + make_smem_ptr(smemA_total), + make_shape(Int{}, Int{}, Int{}), + make_stride(Int{}, Int{}, Int<1>{}) + ); + + auto gH = make_tensor( + make_gmem_ptr(hada_ptr), + make_shape(Int{}, Int{}), + make_stride(Int{}, Int<1>{}) + ); + + auto sH = make_tensor( + make_smem_ptr(smemhada), + make_shape(Int{}, Int{}), + make_stride(Int{}, Int<1>{}) + ); + + auto sA = sA_total(threadIdx.y * CHUNKS, _, _); + + // 每个线程 load 8 个 elements + using CopyAtom = Copy_Atom, T>; + + auto copyA = make_tiled_copy( + CopyAtom{}, + Layout>{}, + Layout>{} + ); + + auto thr_copy_a = copyA.get_slice(threadIdx.x); + + auto tAgA = thr_copy_a.partition_S(gA_total(threadIdx.y * CHUNKS, _, _)); + auto tAsA = thr_copy_a.partition_D(sA); + + auto tHgH = thr_copy_a.partition_S(gH); + auto tHsH = thr_copy_a.partition_D(sH); + + if (threadIdx.y == 0) { + copy(tHgH, tHsH); + } + copy(tAgA, tAsA); + __syncwarp(); + + cp_async_fence(); + + using MMA_Atom_Arch = MMA_Atom; + + // 一个 16x8x16 atom,沿 M 方向铺 2 份 => 16x16x16 + auto mma = make_tiled_mma( + MMA_Atom_Arch{}, + Layout>{}, + Layout>{} + ); + + for (int loop = 1; loop < CHUNKS; ++loop) { + // 先 load 下一批 A,再计算 + auto sA_back = sA_total(threadIdx.y * CHUNKS + loop, _, _); + auto tAgA_back = thr_copy_a.partition_S( + gA_total(threadIdx.y * CHUNKS + loop, _, _) + ); + auto tAsA_back = thr_copy_a.partition_D(sA_back); + copy(tAgA_back, tAsA_back); + + cp_async_wait<0>(); + + if (threadIdx.y == 0 && threadIdx.x == 0) + { + print_tensor(gA_total(threadIdx.y * CHUNKS + loop - 1, _, _)); + print_tensor(sA); + print_tensor(sH); + } + // 计算 H * (A * H) + auto thr_mma = mma.get_slice(threadIdx.x); + + // 1) 右乘 H: A x H -> C + auto tCsA = thr_mma.partition_A(sA); + auto tCsB = thr_mma.partition_B(sH); + auto tCsC = thr_mma.partition_C(sA); + + auto tCrC = thr_mma.make_fragment_C(tCsC); + + clear(tCrC); + gemm(mma, tCsA, 
tCsB, tCrC); + copy(tCrC, tCsC); + __syncwarp(); + if (threadIdx.y == 0 && threadIdx.x == 0 && loop == 1) + { + print_tensor(sA); + } + + // 2) 左乘 H: H x C -> A + auto sAt = make_tensor(sA.data(), + make_shape(Int{}, Int{}), + make_stride(Int<1>{}, Int{})); + auto tCsH = thr_mma.partition_A(sH); + auto tCsHC = thr_mma.partition_B(sAt); + auto tCsC2 = thr_mma.partition_C(sA); + auto tCrC2 = thr_mma.make_fragment_C(tCsC2); + + clear(tCrC2); + gemm(mma, tCsH, tCsHC, tCrC2); + copy(tCrC2, tCsC2); + + __syncwarp(); + // 完成数据 load 再进行下一批 work + cp_async_fence(); + sA = sA_back; + } + cp_async_wait<0>(); + // 计算 H * (A * H) + auto thr_mma = mma.get_slice(threadIdx.x); + + // 1) 右乘 H: A x H -> C + auto tCsA = thr_mma.partition_A(sA); + auto tCsB = thr_mma.partition_B(sH); + auto tCsC = thr_mma.partition_C(sA); + + auto tCrC = thr_mma.make_fragment_C(tCsC); + + clear(tCrC); + gemm(mma, tCsA, tCsB, tCrC); + copy(tCrC, tCsC); + __syncwarp(); + + // 2) 左乘 H: H x C -> A + auto sAt = make_tensor(sA.data(), + make_shape(Int{}, Int{}), + make_stride(Int<1>{}, Int{})); + auto tCsH = thr_mma.partition_A(sH); + auto tCsHC = thr_mma.partition_B(sAt); + auto tCsC2 = thr_mma.partition_C(sA); + auto tCrC2 = thr_mma.make_fragment_C(tCsC2); + + clear(tCrC2); + gemm(mma, tCsH, tCsHC, tCrC2); + copy(tCrC2, tCsC2); + + // 需要block的全部thread同步了 + __syncthreads(); + + new + +} + +void test_small() +{ + constexpr int R_WIDTH = 128; // 8 * 16 + constexpr int ROWS = R_WIDTH / M; // 8 + + // 准备输入数据 (8行16列) + std::vector A_h(R_WIDTH); + for (int i = 0; i < R_WIDTH; ++i) { + A_h[i] = half_t(i / 100.0f); + } + + // 打印输入数据 + printf("Input A (8x16):\n"); + for (int r = 0; r < ROWS; ++r) { + for (int c = 0; c < M; ++c) { + printf("%6.2f ", float(A_h[r * M + c])); + } + printf("\n"); + } + + // 分配 GPU 内存 + half_t *A_d, *O_d; + cudaMalloc(&A_d, R_WIDTH * sizeof(half_t)); + cudaMalloc(&O_d, R_WIDTH * sizeof(half_t)); + + // 拷贝数据到 GPU + cudaMemcpy(A_d, A_h.data(), R_WIDTH * sizeof(half_t), 
cudaMemcpyHostToDevice); + + // 调用 kernel + hada_core_less_256<<<1, 32>>>(A_d, O_d); + + // 等待完成 + cudaDeviceSynchronize(); + CUDA_CHECK(cudaGetLastError()); + + // 拷贝结果回主机 + std::vector result(R_WIDTH); + cudaMemcpy(result.data(), A_d, R_WIDTH * sizeof(half_t), cudaMemcpyDeviceToHost); + + // 打印结果 + printf("\nOutput A after H * A * H (8x16):\n"); + for (int r = 0; r < ROWS; ++r) { + for (int c = 0; c < M; ++c) { + printf("%6.2f ", float(result[r * M + c])); + } + printf("\n"); + } + + // 释放内存 + cudaFree(A_d); + cudaFree(O_d); +} + +void test_large() +{ + constexpr int R_WIDTH = 1024; // 总行宽 + constexpr int CHUNKS = 2; // 每个 warp 处理的 chunk 数 + constexpr int WARPS = R_WIDTH / 256 / CHUNKS; // = 2 + + // 准备输入数据 (512 = 32行 x 16列) + std::vector A_h(R_WIDTH); + for (int i = 0; i < R_WIDTH; ++i) + { + A_h[i] = half_t(i / 100.0f); // 0,1,2,...,15,0,1,2,... + } + // 分配 GPU 内存 + half_t *A_d; + cudaMalloc(&A_d, R_WIDTH * sizeof(half_t)); + + // 拷贝数据到 GPU + cudaMemcpy(A_d, A_h.data(), R_WIDTH * sizeof(half_t), cudaMemcpyHostToDevice); + + // 计算 dynamic shared memory 大小 + // 每个 chunk 是 16x16,每个 warp 处理 CHUNKS 个 + size_t smem_size = WARPS * CHUNKS * M * M * sizeof(half_t); + + printf("\nLaunching kernel: R_WIDTH=%d, CHUNKS=%d, WARPS=%d\n", R_WIDTH, CHUNKS, WARPS); + printf("Block dim: (%d, %d, 1), Dynamic smem: %zu bytes\n\n", 32, WARPS, smem_size); + + // 调用 kernel + dim3 block(32, WARPS); + hadacore_large<<<1, block, smem_size>>>(A_d); + + // 等待完成 + cudaDeviceSynchronize(); + CUDA_CHECK(cudaGetLastError()); + + // 拷贝结果回主机 + std::vector result(R_WIDTH); + cudaMemcpy(result.data(), A_d, R_WIDTH * sizeof(half_t), cudaMemcpyDeviceToHost); + + // 打印结果 (前32行) + printf("Output A after H * A * H (first 32x16):\n"); + for (int r = 0; r < 32; ++r) { + for (int c = 0; c < M; ++c) { + printf("%6.1f ", float(result[r * M + c])); + } + printf("\n"); + } + + // 释放内存 + cudaFree(A_d); +} +} + diff --git a/04_hadamard_tc/flashzxi/src/main.cu b/04_hadamard_tc/flashzxi/src/main.cu new file mode 
100644 index 0000000..3e230d4 --- /dev/null +++ b/04_hadamard_tc/flashzxi/src/main.cu @@ -0,0 +1,12 @@ +// +// Created by core_dump on 3/14/26. +// +#include "hadacore.hpp" + +int main() +{ + // hadacore::test_small(); + printf("\n========================================\n\n"); + hadacore::test_large(); + return 0; +} \ No newline at end of file From 9b242ec42ebdd8f59babbdfb74397838412b41b3 Mon Sep 17 00:00:00 2001 From: flashzxi Date: Mon, 16 Mar 2026 22:53:05 +0800 Subject: [PATCH 5/5] update --- 03_nf4_dequant/flashzxi/Report.md | 16 + 03_nf4_dequant/flashzxi/include/nf4_dequant.h | 2 +- 03_nf4_dequant/flashzxi/src/main.cu | 6 +- .../flashzxi/src/nf4_dequant_warp8.cu | 4 +- 03_nf4_dequant/flashzxi/test/data/baseline.py | 4 +- 04_hadamard_tc/flashzxi/src/hadacore.cu | 285 +++++++----------- 04_hadamard_tc/flashzxi/src/main.cu | 24 +- 7 files changed, 146 insertions(+), 195 deletions(-) create mode 100644 03_nf4_dequant/flashzxi/Report.md diff --git a/03_nf4_dequant/flashzxi/Report.md b/03_nf4_dequant/flashzxi/Report.md new file mode 100644 index 0000000..258aada --- /dev/null +++ b/03_nf4_dequant/flashzxi/Report.md @@ -0,0 +1,16 @@ +## NF4 反量化 +author: flashzxi + +本项目是利用cuda高效计算nf4反量化,对比bitsandbytes 实现 + +本项目的假设: +每个block大小为64个元素 + +二级量化每个group包含256个block. 
+ +## 实现 +总共实现了三个版本,一个最简单的naive版本,一个二级反量化和一级反量化分开计算的版本以及最终的单独kernel解两层反量化的版本。其中naive版本在`src/nf4_dequant_naive.cu`,其余两个版本都在`src/nf4_dequant_warp8.cu` + + + +开发工程中,我尝试 \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/include/nf4_dequant.h b/03_nf4_dequant/flashzxi/include/nf4_dequant.h index 1c8265c..85dae58 100644 --- a/03_nf4_dequant/flashzxi/include/nf4_dequant.h +++ b/03_nf4_dequant/flashzxi/include/nf4_dequant.h @@ -7,4 +7,4 @@ #include "quant_state.h" void nf4_dequant_naive(const QuantState& quant_state, __half* output); void nf4_dequant_warp8_batch32_two_phase(const QuantState& quant_state, __half* output); -void nf4_dequant_warp8_batch32_one_phase(const QuantState& quant_state, __half* output); \ No newline at end of file +void nf4_dequant_warp8_batch8_one_phase(const QuantState& quant_state, __half* output); \ No newline at end of file diff --git a/03_nf4_dequant/flashzxi/src/main.cu b/03_nf4_dequant/flashzxi/src/main.cu index 161261c..ce33734 100644 --- a/03_nf4_dequant/flashzxi/src/main.cu +++ b/03_nf4_dequant/flashzxi/src/main.cu @@ -10,9 +10,9 @@ int main() { int row = 10000; int col = 10000; - std::string file_prefix = std::string("/home/flashzxi/CLionProjects/Learning-CUDA/03_nf4_dequant/flashzxi/test/data/nf4_") + std::to_string(row) + "x" + std::to_string(col) + "_fp16"; + std::string file_prefix = std::string("/home/core_dump/Learning-CUDA/03_nf4_dequant/flashzxi/nf4_") + std::to_string(row) + "x" + std::to_string(col) + "_fp16"; auto conf = parse_quant_state(file_prefix + ".bin", - "/home/flashzxi/CLionProjects/Learning-CUDA/03_nf4_dequant/flashzxi/test/conf/blocksize64_fp16_T4.ini", + "/home/core_dump/Learning-CUDA/03_nf4_dequant/flashzxi/test/conf/blocksize64_fp16_T4.ini", file_prefix + "_w_dequant.bin"); // conf.print(); @@ -25,7 +25,7 @@ int main() { std::cout << std::endl; __half* ans = new __half[conf.num_elements]; - nf4_dequant_warp8_batch32_one_phase(conf, ans); + nf4_dequant_warp8_batch8_one_phase(conf, ans); float 
max_diff = 0.f; for (int i = 0; i < conf.num_rows; i++) { diff --git a/03_nf4_dequant/flashzxi/src/nf4_dequant_warp8.cu b/03_nf4_dequant/flashzxi/src/nf4_dequant_warp8.cu index fc7332d..118aa05 100644 --- a/03_nf4_dequant/flashzxi/src/nf4_dequant_warp8.cu +++ b/03_nf4_dequant/flashzxi/src/nf4_dequant_warp8.cu @@ -314,8 +314,8 @@ void nf4_dequant_warp8_batch32_two_phase(const QuantState& quant_state, __half* CUDA_CHECK(cudaFree(unpacked_weights_s)); } -void nf4_dequant_warp8_batch32_one_phase(const QuantState& quant_state, __half* output) { - constexpr int PROCESS_SIZE = 32; +void nf4_dequant_warp8_batch8_one_phase(const QuantState& quant_state, __half* output) { + constexpr int PROCESS_SIZE = 8; uint8_t* absmax_q_s; __half* code2_s; diff --git a/03_nf4_dequant/flashzxi/test/data/baseline.py b/03_nf4_dequant/flashzxi/test/data/baseline.py index 4ca6f79..857e548 100644 --- a/03_nf4_dequant/flashzxi/test/data/baseline.py +++ b/03_nf4_dequant/flashzxi/test/data/baseline.py @@ -130,8 +130,8 @@ def save_dequant_result(path: str, deq: torch.Tensor, shape, out_dtype=torch.flo if __name__ == "__main__": torch.manual_seed(0) - torch.manual_seed(1234) # CPU RNG - torch.cuda.manual_seed_all(1234) # 所有 GPU RNG(单卡也可用) + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) row = 10000 col = 10000 W = torch.randn(row, col, device="cuda", dtype=torch.float16) diff --git a/04_hadamard_tc/flashzxi/src/hadacore.cu b/04_hadamard_tc/flashzxi/src/hadacore.cu index c3851fe..dabd2e5 100644 --- a/04_hadamard_tc/flashzxi/src/hadacore.cu +++ b/04_hadamard_tc/flashzxi/src/hadacore.cu @@ -18,180 +18,60 @@ __device__ __constant__ uint16_t H16_bf16_bin[M * M] = { __device__ __constant__ half_t* H16_fp16 = (half_t*) H16_fp16_bin; __device__ __constant__ bfloat16_t* H16_bf16 = (bfloat16_t*) H16_bf16_bin; -// 处理 64 < R_WIDTH < 256 -template -__global__ void hada_core_less_256(T* A, T* O_scope) { - - constexpr int ROWS = R_WIDTH / M; - - __shared__ __align__(32) int16_t smemA_bin[M * M]; - 
__shared__ __align__(32) int16_t smemhada_bin[M * M]; - __shared__ __align__(32) int16_t smemC_bin[M * M]; // 暂存 A * H - - T* smemA = (T*) smemA_bin; - T* smemC = (T*) smemC_bin; - T* smemhada = (T*) smemhada_bin; - - T* hada_ptr = nullptr; - - if constexpr (std::is_same_v) { - hada_ptr = H16_fp16; - } else { - hada_ptr = H16_bf16; - } - - auto gA = make_tensor( - make_gmem_ptr(A), - make_shape(Int{}, Int{}), - make_stride(Int{}, Int<1>{}) - ); - - auto sA = make_tensor( - make_smem_ptr(smemA), - make_shape(Int{}, Int{}), - make_stride(Int{}, Int<1>{}) - ); - - auto thr_sA = local_partition( - sA, - Layout>{}, - threadIdx.x - ); - - clear(thr_sA); // 清空 - __syncthreads(); - - auto gH = make_tensor( - make_gmem_ptr(hada_ptr), - make_shape(Int{}, Int{}), - make_stride(Int{}, Int<1>{}) - ); - - auto sH = make_tensor( - make_smem_ptr(smemhada), - make_shape(Int{}, Int{}), - make_stride(Int{}, Int<1>{}) - ); - - auto sC = make_tensor( - make_smem_ptr(smemC), - make_shape(Int{}, Int{}), - make_stride(Int{}, Int<1>{}) - ); - - using CopyAtom = Copy_Atom, T>; - - auto copyA = make_tiled_copy( - CopyAtom{}, - Layout, _2>>{}, - Layout>{} - ); - - auto copyH = make_tiled_copy( - CopyAtom{}, - Layout>{}, - Layout>{} - ); - - auto sA_sub = local_tile( - sA, - make_shape(Int{}, Int{}), - make_coord(0, 0) - ); - - auto thr_copy_a = copyA.get_slice(threadIdx.x); - auto tAgA = thr_copy_a.partition_S(gA); - auto tAsA = thr_copy_a.partition_D(sA_sub); - - auto thr_copy_h = copyH.get_slice(threadIdx.x); - auto tHgH = thr_copy_h.partition_S(gH); - auto tHsH = thr_copy_h.partition_D(sH); - - copy(tAgA, tAsA); - copy(tHgH, tHsH); - - if (threadIdx.x == 0) { - print_tensor(sA); - print_tensor(sH); - } - - __syncthreads(); - - using MMA_Atom_Arch = MMA_Atom; - - // 一个 16x8x16 atom,沿 M 方向铺 2 份 => 16x16x16 - auto mma = make_tiled_mma( - MMA_Atom_Arch{}, - Layout>{}, - Layout>{} - ); - - auto thr_mma = mma.get_slice(threadIdx.x); - - // ------------------------------- - // 1) 右乘 H: A x H -> C 
- // ------------------------------- - - auto tCsA = thr_mma.partition_A(sA); // A - auto tCsB = thr_mma.partition_B(sH); // H - auto tCsC = thr_mma.partition_C(sC); // C - - auto tCrC = thr_mma.make_fragment_C(tCsC); - clear(tCrC); - - gemm(mma, tCsA, tCsB, tCrC); - - copy(tCrC, tCsC); - - __syncthreads(); - - if (threadIdx.x == 0) { - print_tensor(sC); - } - - auto sCt = make_tensor( - make_smem_ptr(smemC), - make_shape(Int{}, Int{}), - make_stride(Int<1>{}, Int{}) - ); // sC 的转置 view - - auto t2sA = thr_mma.partition_A(sH); - auto t2sB = thr_mma.partition_B(sCt); - auto t2sC = thr_mma.partition_C(sA); - - auto t2rC = thr_mma.make_fragment_C(t2sC); - clear(t2rC); - - gemm(mma, t2sA, t2sB, t2rC); - - copy(t2rC, t2sC); - __syncthreads(); +// 对角Hadamard矩阵 +__device__ __constant__ uint16_t H2_diag_fp16_bin[M * M] = { +#include "../include/h2_diag_fp16.inc" +}; +__device__ __constant__ uint16_t H2_diag_bf16_bin[M * M] = { +#include "../include/h2_diag_bf16.inc" +}; +__device__ __constant__ half_t* H2_diag_fp16 = (half_t*) H2_diag_fp16_bin; +__device__ __constant__ bfloat16_t* H2_diag_bf16 = (bfloat16_t*) H2_diag_bf16_bin; - if (threadIdx.x == 0) { - print_tensor(sA); // row-major 查看结果 - } - copy(tAsA, tAgA); +__device__ __constant__ uint16_t H4_diag_fp16_bin[M * M] = { +#include "../include/h4_diag_fp16.inc" +}; +__device__ __constant__ uint16_t H4_diag_bf16_bin[M * M] = { +#include "../include/h4_diag_bf16.inc" +}; +__device__ __constant__ half_t* H4_diag_fp16 = (half_t*) H4_diag_fp16_bin; +__device__ __constant__ bfloat16_t* H4_diag_bf16 = (bfloat16_t*) H4_diag_bf16_bin; - __syncthreads(); -} +__device__ __constant__ uint16_t H8_diag_fp16_bin[M * M] = { +#include "../include/h8_diag_fp16.inc" +}; +__device__ __constant__ uint16_t H8_diag_bf16_bin[M * M] = { +#include "../include/h8_diag_bf16.inc" +}; +__device__ __constant__ half_t* H8_diag_fp16 = (half_t*) H8_diag_fp16_bin; +__device__ __constant__ bfloat16_t* H8_diag_bf16 = (bfloat16_t*) H8_diag_bf16_bin; // 
每个block负责计算一行 // 一次计算256,一个warp计算 CHUNKS 个256 // 一个block R_WIDTH / 256 / CHUNKS 个warp template -__global__ void hadacore_large(const T* A) { +__global__ void hadacore_less_than_4096(T* A) { constexpr int WARPS = R_WIDTH / 256 / CHUNKS; extern __shared__ __align__(16) char smemA[]; - __shared__ __align__(32) int16_t smemhada_bin[M * M]; + __shared__ __align__(32) int16_t smemhada_bin1[M * M]; + __shared__ __align__(32) int16_t smemhada_bin2[M * M]; T* smemA_total = (T*) smemA; - T* smemhada = (T*)smemhada_bin; - T* hada_ptr = nullptr; + T* smemhada1 = (T*)smemhada_bin1; + T* smemhada2 = (T*)smemhada_bin2; + T* hada1_ptr = nullptr; + T* hada2_ptr = nullptr; if constexpr (std::is_same_v) { - hada_ptr = H16_fp16; + hada1_ptr = H16_fp16; } else { - hada_ptr = H16_bf16; + hada1_ptr = H16_bf16; + } + + constexpr int log_r_width = 31 - __builtin_clz(R_WIDTH); + if (log_r_width > 8) + { + if(R_WIDTH) } auto gA_total = make_tensor( @@ -206,14 +86,20 @@ __global__ void hadacore_large(const T* A) { make_stride(Int{}, Int{}, Int<1>{}) ); - auto gH = make_tensor( - make_gmem_ptr(hada_ptr), + auto gH1 = make_tensor( + make_gmem_ptr(hada1_ptr), + make_shape(Int{}, Int{}), + make_stride(Int{}, Int<1>{}) + ); + + auto sH1 = make_tensor( + make_smem_ptr(smemhada1), make_shape(Int{}, Int{}), make_stride(Int{}, Int<1>{}) ); - auto sH = make_tensor( - make_smem_ptr(smemhada), + auto sH2 = make_tensor( + make_smem_ptr(smemhada2), make_shape(Int{}, Int{}), make_stride(Int{}, Int<1>{}) ); @@ -234,17 +120,25 @@ __global__ void hadacore_large(const T* A) { auto tAgA = thr_copy_a.partition_S(gA_total(threadIdx.y * CHUNKS, _, _)); auto tAsA = thr_copy_a.partition_D(sA); - auto tHgH = thr_copy_a.partition_S(gH); - auto tHsH = thr_copy_a.partition_D(sH); + auto tHgH1 = thr_copy_a.partition_S(gH1); + auto tHsH1 = thr_copy_a.partition_D(sH1); + + auto tHgH2 = thr_copy_a.partition_S(gH2); + auto tHsH2 = thr_copy_a.partition_D(sH2); if (threadIdx.y == 0) { - copy(tHgH, tHsH); + copy(tHgH1, tHsH1); + 
if (gH2 != nullptr) + { + copy(tHgH2, tHsH2); + } + } + if (threadIdx.x < R_WIDTH / 16) + { + copy(tAgA, tAsA); } - copy(tAgA, tAsA); __syncwarp(); - cp_async_fence(); - using MMA_Atom_Arch = MMA_Atom; // 一个 16x8x16 atom,沿 M 方向铺 2 份 => 16x16x16 @@ -269,14 +163,14 @@ __global__ void hadacore_large(const T* A) { { print_tensor(gA_total(threadIdx.y * CHUNKS + loop - 1, _, _)); print_tensor(sA); - print_tensor(sH); + print_tensor(sH1); } // 计算 H * (A * H) auto thr_mma = mma.get_slice(threadIdx.x); // 1) 右乘 H: A x H -> C auto tCsA = thr_mma.partition_A(sA); - auto tCsB = thr_mma.partition_B(sH); + auto tCsB = thr_mma.partition_B(sH1); auto tCsC = thr_mma.partition_C(sA); auto tCrC = thr_mma.make_fragment_C(tCsC); @@ -294,7 +188,7 @@ __global__ void hadacore_large(const T* A) { auto sAt = make_tensor(sA.data(), make_shape(Int{}, Int{}), make_stride(Int<1>{}, Int{})); - auto tCsH = thr_mma.partition_A(sH); + auto tCsH = thr_mma.partition_A(sH1); auto tCsHC = thr_mma.partition_B(sAt); auto tCsC2 = thr_mma.partition_C(sA); auto tCrC2 = thr_mma.make_fragment_C(tCsC2); @@ -314,7 +208,7 @@ __global__ void hadacore_large(const T* A) { // 1) 右乘 H: A x H -> C auto tCsA = thr_mma.partition_A(sA); - auto tCsB = thr_mma.partition_B(sH); + auto tCsB = thr_mma.partition_B(sH1); auto tCsC = thr_mma.partition_C(sA); auto tCrC = thr_mma.make_fragment_C(tCsC); @@ -323,12 +217,20 @@ __global__ void hadacore_large(const T* A) { gemm(mma, tCsA, tCsB, tCrC); copy(tCrC, tCsC); __syncwarp(); + if (R_WIDTH < 256) + { + if (threadIdx.x < R_WIDTH / 16) + { + copy(tAsA, tAgA); + } + return; + } // 2) 左乘 H: H x C -> A auto sAt = make_tensor(sA.data(), make_shape(Int{}, Int{}), make_stride(Int<1>{}, Int{})); - auto tCsH = thr_mma.partition_A(sH); + auto tCsH = thr_mma.partition_A(sH1); auto tCsHC = thr_mma.partition_B(sAt); auto tCsC2 = thr_mma.partition_C(sA); auto tCrC2 = thr_mma.make_fragment_C(tCsC2); @@ -337,11 +239,34 @@ __global__ void hadacore_large(const T* A) { gemm(mma, tCsH, tCsHC, 
tCrC2); copy(tCrC2, tCsC2); - // 需要block的全部thread同步了 - __syncthreads(); + auto origin_layout = make_layout( + make_shape(Int{}, Int<256>{}), + make_stride(Int<256>{}, Int<1>{})); + auto new_view = make_layout( + make_shape(Int<16>{}, Int<16>{}), + make_stride(Int<16>{}, Int<1>{})); + auto real_layout = composition(origin_layout, new_view); + + for (int i = 0; i < CHUNKS; ++i) + { + int cols = 256 / CHUNKS * WARPS; + auto new_tensor = make_tensor( + make_smem_ptr(smemA_total + cols), real_layout + ); - new + auto tCsA = thr_mma.partition_A(new_tensor); + auto tCsB = thr_mma.partition_B(sH2); + auto tCsC = thr_mma.partition_C(new_tensor); + + auto tCrC = thr_mma.make_fragment_C(tCsC); + clear(tCrC); + gemm(mma, tCsA, tCsB, tCrC); + copy(tCrC, tCsC); + } + + // 需要block的全部thread同步了 + __syncthreads(); } void test_small() @@ -418,14 +343,14 @@ void test_large() // 计算 dynamic shared memory 大小 // 每个 chunk 是 16x16,每个 warp 处理 CHUNKS 个 - size_t smem_size = WARPS * CHUNKS * M * M * sizeof(half_t); + size_t smem_size = std::max(WARPS * CHUNKS * M * M * sizeof(half_t), 16 * sizeof(half_t)); printf("\nLaunching kernel: R_WIDTH=%d, CHUNKS=%d, WARPS=%d\n", R_WIDTH, CHUNKS, WARPS); printf("Block dim: (%d, %d, 1), Dynamic smem: %zu bytes\n\n", 32, WARPS, smem_size); // 调用 kernel dim3 block(32, WARPS); - hadacore_large<<<1, block, smem_size>>>(A_d); + hadacore_less_than_4096<<<1, block, smem_size>>>(A_d); // 等待完成 cudaDeviceSynchronize(); diff --git a/04_hadamard_tc/flashzxi/src/main.cu b/04_hadamard_tc/flashzxi/src/main.cu index 3e230d4..a32b829 100644 --- a/04_hadamard_tc/flashzxi/src/main.cu +++ b/04_hadamard_tc/flashzxi/src/main.cu @@ -2,11 +2,21 @@ // Created by core_dump on 3/14/26. 
// #include "hadacore.hpp" +#include +using namespace cute; -int main() -{ - // hadacore::test_small(); - printf("\n========================================\n\n"); - hadacore::test_large(); - return 0; -} \ No newline at end of file +int main() { + auto layout = make_layout(make_shape(Int<4>{}, Int<256>{}), + make_stride(Int<256>{}, Int<1>{})); + + auto B = make_layout(Shape<_16, _16>{}, Stride<_16, _1>{}); + auto new_layout = composition(layout, B); + print_layout(new_layout); // 直接打印二维“坐标 -> index”表 +} +// int main() +// { +// // hadacore::test_small(); +// printf("\n========================================\n\n"); +// hadacore::test_large(); +// return 0; +// } \ No newline at end of file