
feat: add rust kernels library for loading kernels#421

Open
drbh wants to merge 2 commits intomainfrom
add-kernels-rs

Conversation

Collaborator

@drbh drbh commented Mar 31, 2026

This PR adds a new client library for loading HF kernels in Rust. It allows TVM FFI-based kernels to be called from Rust and optionally integrates with candle for a better tensor UX.

Example usage with candle:

use candle_core::{Device, Tensor};
use kernels::Result;
use kernels::candle::CallKernel;

fn main() -> Result<()> {
    let activation = kernels::candle::get_kernel("drbh/relu-tvm", 1)?;
    let device = activation.device()?;
    println!("Backend: {}", activation.backend());

    let x = Tensor::new(&[-1.0f32, 2.0, -3.0, 4.0, -0.5, 0.0, 1.5, -2.5], &device)?;
    let y = Tensor::zeros_like(&x)?;
    activation.call("relu", &[&y, &x])?;

    let result = y.to_vec1::<f32>()?;
    let expected = Tensor::new(&*x.to_vec1::<f32>()?, &Device::Cpu)?
        .relu()?
        .to_vec1::<f32>()?;

    println!("Input:    {:?}", x.to_vec1::<f32>()?);
    println!("TVM FFI:  {result:?}");
    println!("Candle:   {expected:?}");
    assert_eq!(result, expected);
    println!("OK");
    Ok(())
}

Repo with candle and non-candle examples: https://github.com/drbh/hf-kernels-rust

@drbh drbh marked this pull request as ready for review April 1, 2026 15:01
Comment on lines +74 to +93
pub fn detect_cuda_version() -> Option<String> {
    cuda_version_from_smi().or_else(cuda_version_from_nvcc)
}

fn cuda_version_from_smi() -> Option<String> {
    let output = Command::new("nvidia-smi").output().ok()?;
    if !output.status.success() {
        return None;
    }
    let stdout = String::from_utf8_lossy(&output.stdout);
    let rest = stdout.split("CUDA Version:").nth(1)?;
    Some(rest.split_whitespace().next()?.to_string())
}

fn cuda_version_from_nvcc() -> Option<String> {
    let output = Command::new("nvcc").arg("--version").output().ok()?;
    let stdout = String::from_utf8_lossy(&output.stdout);
    let after = stdout.split("release ").nth(1)?;
    Some(after.split(',').next()?.trim().to_string())
}
Member


This may not be the same as the library that a framework is compiled against and dynamically loads. Also, nvidia-smi gives the driver library version, not the CUDA runtime version. We need to get it from cudart, e.g. see:

def _get_cuda() -> Optional[CUDA]:

libloading seems to be the most widely used library for dlopen:

https://github.com/nagisa/rust_libloading/

Comment on lines +15 to +24
pub fn candle_device(self) -> Result<Device> {
    match self {
        BackendKind::Cpu => Ok(Device::Cpu),
        #[cfg(feature = "candle-cuda")]
        BackendKind::Cuda => Device::new_cuda(0).map_err(Into::into),
        #[cfg(not(feature = "candle-cuda"))]
        BackendKind::Cuda => Ok(Device::Cpu),
        BackendKind::Xpu => Ok(Device::Cpu),
    }
}
Member


I think this can be TryFrom<BackendDevice> for Device. Not 100% sure if it works with the coherence rules, since it's a different mod in the same crate. But I think it should.

Comment on lines +26 to +34
pub fn candle_supported(self) -> Self {
    match self {
        #[cfg(feature = "candle-cuda")]
        BackendKind::Cuda => BackendKind::Cuda,
        #[cfg(not(feature = "candle-cuda"))]
        BackendKind::Cuda => BackendKind::Cpu,
        other => other,
    }
}
Member


The function name is not very descriptive, maybe to_candle_supported?

Comment on lines +43 to +44
    #[allow(unreachable_patterns)]
    _ => BackendKind::Cpu,
Member


I think it would be better to explicitly enumerate the other variants here, so that we can rely on exhaustiveness checking when new variants get added?

Also, as written, if Candle returns a device type that we don't support, it silently falls back to Cpu, which would select kernels that are not compatible with the actual device type?

Comment on lines +75 to +79
macro_rules! ptr {
    ($v:expr) => {
        Ok(unsafe { $v.as_ptr().add(offset) as *mut c_void })
    };
}
Member


Remove, make explicit.

Comment on lines +101 to +107
macro_rules! ptr {
    ($slice:expr) => {{
        let view = $slice.slice(offset..);
        let (device_ptr, _sync) = view.device_ptr(&stream);
        Ok(device_ptr as *mut c_void)
    }};
}
Member


I think rather than a macro, this could be a trait + impl? At least I think with a generic type it should work with one implementation for all cases?

// Tensors are passed to the kernel as DLPack pointers directly into
// candle's storage - no copies for contiguous tensors.
pub trait CallKernel {
    fn call(&self, func_name: &str, args: &[&Tensor]) -> Result<()>;
Member


What if there are non-tensor arguments, e.g. boolean options, epsilon floats, etc.?

Comment on lines +183 to +187
let device_type = match kind {
    BackendKind::Cpu => tvm_ffi::DL_CPU,
    BackendKind::Cuda => tvm_ffi::DL_CUDA,
    BackendKind::Xpu => tvm_ffi::DL_ONEAPI,
};
Member


Seems like this could use a From implementation outside the function?
