Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3,446 changes: 3,446 additions & 0 deletions kernels-rs/Cargo.lock

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions kernels-rs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[package]
name = "kernels"
version = "0.1.0"
edition = "2024"
description = "Load and call Hugging Face Hub kernels in Rust"
homepage = "https://github.com/huggingface/kernels"
license = "Apache-2.0"
repository = "https://github.com/huggingface/kernels"

[features]
default = []
candle = ["dep:candle-core"]
candle-cuda = ["candle", "candle-core/cuda", "dep:cudarc"]

[dependencies]
candle-core = { version = "0.10.0", optional = true }
cudarc = { version = "0.19.0", optional = true }
huggingface-hub = { git = "https://github.com/huggingface/huggingface_hub_rust.git", rev = "8cbc662035e04d4be8e829316272893e980f5926", package = "huggingface-hub", features = ["blocking"] }
libc = "0.2"
libloading = "0.8"
thiserror = "1"
walkdir = "2"
118 changes: 118 additions & 0 deletions kernels-rs/src/backend.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use std::fmt;
use std::str::FromStr;

use libloading::Library;

use crate::error::Error;

/// A compute backend together with any version information attached to it.
///
/// Use [`Backend::kind`] to obtain the version-less [`BackendKind`].
#[derive(Clone, Debug)]
pub enum Backend {
    /// CPU backend; carries no version.
    Cpu,
    /// CUDA backend.
    Cuda {
        /// Version string for the CUDA backend.
        // NOTE(review): presumably the `major.minor` runtime version produced by
        // `detect_cuda_version` — confirm against the call sites.
        version: String,
    },
    /// XPU backend.
    Xpu {
        /// Version string for the XPU backend.
        version: String,
    },
}

impl Backend {
    /// Returns the version-less [`BackendKind`] corresponding to this backend.
    pub fn kind(&self) -> BackendKind {
        match self {
            Self::Cpu => BackendKind::Cpu,
            Self::Cuda { .. } => BackendKind::Cuda,
            Self::Xpu { .. } => BackendKind::Xpu,
        }
    }

    /// Returns the lowercase backend name (`"cpu"`, `"cuda"` or `"xpu"`).
    ///
    /// The name comes from [`BackendKind::as_str`] and is therefore `'static`;
    /// the return type says so explicitly so callers can keep the name after
    /// the `Backend` value itself has been dropped.
    pub fn name(&self) -> &'static str {
        self.kind().as_str()
    }
}

impl fmt::Display for Backend {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Backend::Cpu => write!(f, "cpu"),
Backend::Cuda { version } => write!(f, "cuda {version}"),
Backend::Xpu { version } => write!(f, "xpu {version}"),
}
}
}

/// The family of a compute backend, without any version information.
///
/// Converts to and from its lowercase name via [`BackendKind::as_str`] and
/// the [`FromStr`] implementation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BackendKind {
    /// CPU backend.
    Cpu,
    /// CUDA backend.
    Cuda,
    /// XPU backend.
    Xpu,
}

impl BackendKind {
pub fn as_str(self) -> &'static str {
match self {
BackendKind::Cpu => "cpu",
BackendKind::Cuda => "cuda",
BackendKind::Xpu => "xpu",
}
}
}

/// Displays the kind as its lowercase name, identical to [`BackendKind::as_str`].
impl fmt::Display for BackendKind {
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = self.as_str();
        out.write_str(label)
    }
}

impl FromStr for BackendKind {
type Err = Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"cpu" => Ok(Self::Cpu),
"cuda" => Ok(Self::Cuda),
"xpu" => Ok(Self::Xpu),
other => Err(Error::Kernel(format!("unknown backend: {other}"))),
}
}
}

/// Queries the CUDA runtime library for its version.
///
/// Dynamically loads the platform's `cudart` library, calls
/// `cudaRuntimeGetVersion` and renders the result as `major.minor`.
///
/// Returns `None` when the library cannot be loaded, the symbol is missing,
/// the call reports a non-zero status, or the reported version is not
/// positive.
// NOTE(review): only the unversioned platform file name from
// `libloading::library_filename("cudart")` is tried; installations that ship
// only a versioned file name would presumably be missed — confirm whether a
// wider search is needed.
pub fn detect_cuda_version() -> Option<String> {
    type GetRuntimeVersionFn = unsafe extern "C" fn(*mut i32) -> i32;

    // SAFETY: loading a shared library runs its initialisation routines; we
    // rely on the CUDA runtime library being safe to load into this process.
    let cudart = unsafe { Library::new(libloading::library_filename("cudart")) }.ok()?;

    // SAFETY: `GetRuntimeVersionFn` mirrors the C declaration
    // `cudaError_t cudaRuntimeGetVersion(int *)`, and the symbol is only used
    // while `cudart` is still loaded.
    let get_runtime_version =
        unsafe { cudart.get::<GetRuntimeVersionFn>(b"cudaRuntimeGetVersion\0") }.ok()?;

    let mut encoded_version: i32 = 0;
    // SAFETY: the pointer refers to a live, writable `i32` for the whole call.
    let status = unsafe { get_runtime_version(&mut encoded_version) };

    // Zero is the success status; anything else means the query failed.
    match status {
        0 => format_cuda_runtime_version(encoded_version),
        _ => None,
    }
}

/// Renders an integer-encoded CUDA runtime version as `major.minor`.
///
/// The input is decoded as `major * 1000 + minor * 10` (so `12080` becomes
/// `"12.8"`). Returns `None` for zero or negative inputs.
fn format_cuda_runtime_version(runtime_version: i32) -> Option<String> {
    (runtime_version > 0).then(|| {
        let (major, remainder) = (runtime_version / 1000, runtime_version % 1000);
        let minor = remainder / 10;
        format!("{major}.{minor}")
    })
}
Comment on lines +75 to +98
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may not be the same as the library that a framework is compiled against and dynamically loads. Also, nvidia-smi gives the driver library version, not the CUDA runtime version. We need to get it from cudart, e.g. see:

def _get_cuda() -> Optional[CUDA]:

libloading seems to be the most widely used library for dlopen:

https://github.com/nagisa/rust_libloading/

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. In the latest changes I've updated this to prefer querying the version via cudaRuntimeGetVersion from cudart. I've tested it locally without running into any issues; however, I'm not 100% sure whether we need more logic to search for cudart, like ctypes.util.find_library("cudart") does, if it's not in the default location.


/// Picks the backend kind to use on this machine.
///
/// Returns [`BackendKind::Cuda`] when [`detect_cuda_version`] reports a CUDA
/// runtime version, and [`BackendKind::Cpu`] otherwise.
pub fn detect() -> BackendKind {
    detect_cuda_version().map_or(BackendKind::Cpu, |_| BackendKind::Cuda)
}

#[cfg(test)]
mod tests {
    use super::format_cuda_runtime_version;

    /// Checks the `major.minor` rendering of encoded runtime versions and
    /// that a zero version yields `None`.
    #[test]
    fn formats_cuda_runtime_versions() {
        // (encoded runtime version, expected rendering)
        let cases: [(i32, Option<&str>); 3] = [
            (12080, Some("12.8")),
            (11020, Some("11.2")),
            (0, None),
        ];

        for (encoded, expected) in cases {
            let rendered = format_cuda_runtime_version(encoded);
            assert_eq!(rendered.as_deref(), expected);
        }
    }
}
Loading
Loading