From 388cd45eb8eaab89b5dd6f8013a9162370a41b7a Mon Sep 17 00:00:00 2001
From: Di Xu
Date: Thu, 19 Mar 2026 09:02:25 -0700
Subject: [PATCH] Add merge() method to LoRALinear for zero-cost inference
 deployment

Summary:
Implements W_merged = W + (alpha/rank) * B @ A, allowing LoRA weights to be
folded into the base linear layer at deployment time, eliminating additional
inference latency per the LoRA paper (arxiv 2106.09685).

Differential Revision: D97174451
---
 examples/models/llama/lora.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/examples/models/llama/lora.py b/examples/models/llama/lora.py
index 12c1c4e5d68..c84bd4b1140 100644
--- a/examples/models/llama/lora.py
+++ b/examples/models/llama/lora.py
@@ -40,8 +40,37 @@
         self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
         self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
 
+    def merge(self) -> nn.Linear:
+        """Merge LoRA weights into base weight, returning a standard nn.Linear.
+
+        W_merged = W + (alpha / rank) * B @ A
+        This eliminates the LoRA path at inference with zero additional latency.
+        """
+        merged = nn.Linear(self.in_dim, self.out_dim, bias=self.use_bias)
+        merged.weight.data = self.weight + (self.alpha / self.rank) * (
+            self.lora_b.weight @ self.lora_a.weight
+        )
+        if self.use_bias:
+            merged.bias.data = self.bias.data.clone()
+        return merged
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.nn.functional.linear(x, self.weight, self.bias)
         lora_out = self.lora_a(self.dropout(x))
         lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
         return out + lora_out
+
+
+def merge_lora_weights(model: nn.Module) -> nn.Module:
+    """Replace all LoRALinear modules in the model with merged nn.Linear modules.
+
+    Walks the module tree and substitutes each LoRALinear with a standard
+    nn.Linear whose weight is W + (alpha/rank) * B @ A. This eliminates
+    LoRA overhead at inference time.
+    """
+    for _, module in model.named_modules():
+        for attr_name, child in list(module.named_children()):
+            if isinstance(child, LoRALinear):
+                merged = child.merge()
+                setattr(module, attr_name, merged)
+    return model