feat: add Rust kernels library for loading kernels (#421)
Conversation
| pub fn detect_cuda_version() -> Option<String> { | ||
| cuda_version_from_smi().or_else(cuda_version_from_nvcc) | ||
| } | ||
|
|
||
/// Parse the CUDA version advertised by `nvidia-smi`.
///
/// Runs `nvidia-smi`, then scans its stdout for the `CUDA Version:`
/// banner and returns the whitespace-delimited token that follows it.
/// Returns `None` if the binary is missing, exits non-zero, or the
/// banner is absent.
fn cuda_version_from_smi() -> Option<String> {
    let out = Command::new("nvidia-smi").output().ok()?;
    if !out.status.success() {
        return None;
    }
    // Lossy conversion is fine: the banner we search for is pure ASCII.
    let text = String::from_utf8_lossy(&out.stdout);
    text.split("CUDA Version:")
        .nth(1)?
        .split_whitespace()
        .next()
        .map(str::to_string)
}
|
|
||
/// Parse the CUDA toolkit version from `nvcc --version`.
///
/// Extracts the token after `release ` (e.g. "12.4" from
/// "Cuda compilation tools, release 12.4, V12.4.131").
/// Returns `None` if `nvcc` is missing, fails, or the marker is absent.
fn cuda_version_from_nvcc() -> Option<String> {
    let output = Command::new("nvcc").arg("--version").output().ok()?;
    // Fix: check the exit status (the sibling `cuda_version_from_smi`
    // already does). Without this we would try to parse whatever a
    // failed nvcc invocation printed to stdout.
    if !output.status.success() {
        return None;
    }
    let stdout = String::from_utf8_lossy(&output.stdout);
    let after = stdout.split("release ").nth(1)?;
    Some(after.split(',').next()?.trim().to_string())
}
There was a problem hiding this comment.
This may not be the same as the library that a framework is compiled against and dynamically loads. Also, nvidia-smi gives the driver library version, not the CUDA runtime version. We need to get it from cudart, e.g. see:
kernels/kernels/src/kernels/backends.py
Line 254 in 8ed7bb4
libloading seems to be the most widely used library for dlopen:
/// Convert this backend kind into a candle `Device`.
///
/// NOTE(review): two silent fallbacks to confirm — with the
/// `candle-cuda` feature disabled, `Cuda` maps to `Device::Cpu`
/// instead of erroring, and `Xpu` always maps to `Device::Cpu`.
/// Kernels selected for those backends may then be incompatible with
/// the device actually returned.
pub fn candle_device(self) -> Result<Device> { | ||
match self { | ||
BackendKind::Cpu => Ok(Device::Cpu), | ||
// Device ordinal 0 is hard-coded; multi-GPU selection is not supported here.
#[cfg(feature = "candle-cuda")] | ||
BackendKind::Cuda => Device::new_cuda(0).map_err(Into::into), | ||
// Without candle's CUDA support we cannot construct a CUDA device.
#[cfg(not(feature = "candle-cuda"))] | ||
BackendKind::Cuda => Ok(Device::Cpu), | ||
// Candle has no XPU device type; fall back to CPU.
BackendKind::Xpu => Ok(Device::Cpu), | ||
} | ||
}
There was a problem hiding this comment.
I think this can be TryFrom<BackendDevice> for Device. Not 100% sure if it works with the coherency rules, since it's a different mod in the same crate. But I think it should.
| pub fn candle_supported(self) -> Self { | ||
| match self { | ||
| #[cfg(feature = "candle-cuda")] | ||
| BackendKind::Cuda => BackendKind::Cuda, | ||
| #[cfg(not(feature = "candle-cuda"))] | ||
| BackendKind::Cuda => BackendKind::Cpu, | ||
| other => other, | ||
| } | ||
| } |
There was a problem hiding this comment.
The function name is not very descriptive; maybe `to_candle_supported`?
| #[allow(unreachable_patterns)] | ||
| _ => BackendKind::Cpu, |
There was a problem hiding this comment.
I think it would be better to explicitly enumerate the other variants here, so that we can rely on exhaustiveness checking when other variants get added?
Also it seems that as it is, if Candle returns a device type that we don't support, it would result in Cpu, which results in kernels that are not compatible with the device type?
// Host-buffer variant: produce a raw `*mut c_void` kernel argument by
// offsetting into the buffer. Relies on an `offset` variable being in
// scope at the expansion site.
// SAFETY assumption: `offset` must lie within `$v`'s allocation —
// TODO(review) confirm all call sites guarantee this before the unsafe add.
macro_rules! ptr { | ||
($v:expr) => { | ||
Ok(unsafe { $v.as_ptr().add(offset) as *mut c_void }) | ||
}; | ||
}
// CUDA variant: obtain a device pointer by slicing the buffer at
// `offset` and asking for its device address on `stream` (both captured
// from the expansion site).
// NOTE(review): `_sync` is dropped immediately; presumably it ties the
// pointer's validity to the stream — confirm it is safe to discard it
// before the kernel launch.
macro_rules! ptr { | ||
($slice:expr) => {{ | ||
let view = $slice.slice(offset..); | ||
let (device_ptr, _sync) = view.device_ptr(&stream); | ||
Ok(device_ptr as *mut c_void) | ||
}}; | ||
}
There was a problem hiding this comment.
I think rather than a macro, this could be a trait + impl? At least I think with a generic type it should work with one implementation for all cases?
| // Tensors are passed to the kernel as DLPack pointers directly into | ||
| // candle's storage - no copies for contiguous tensors. | ||
| pub trait CallKernel { | ||
| fn call(&self, func_name: &str, args: &[&Tensor]) -> Result<()>; |
There was a problem hiding this comment.
What if there are non-tensor arguments, e.g. option bools, epsilon floats, etc.?
| let device_type = match kind { | ||
| BackendKind::Cpu => tvm_ffi::DL_CPU, | ||
| BackendKind::Cuda => tvm_ffi::DL_CUDA, | ||
| BackendKind::Xpu => tvm_ffi::DL_ONEAPI, | ||
| }; |
There was a problem hiding this comment.
Seems like this could use a From implementation outside the function?
This PR adds a new client library for loading Hugging Face kernels in Rust. It allows TVM-FFI-based kernels to be called from Rust and optionally integrates with candle for a better tensor UX.
Example usage with candle
A repo with candle and non-candle examples: https://github.com/drbh/hf-kernels-rust