diff --git a/examples/models/llama/lora.py b/examples/models/llama/lora.py
index 12c1c4e5d68..c84bd4b1140 100644
--- a/examples/models/llama/lora.py
+++ b/examples/models/llama/lora.py
@@ -40,9 +40,52 @@ def __init__(
 
         self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
         self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
 
+    def merge(self) -> nn.Linear:
+        """Merge LoRA weights into the base weight, returning a plain nn.Linear.
+
+        W_merged = W + (alpha / rank) * B @ A
+        This eliminates the LoRA path at inference with zero additional latency.
+        """
+        # Do the math outside autograd so the merged weight carries no
+        # graph history back into the LoRA parameters.
+        with torch.no_grad():
+            merged_weight = self.weight + (self.alpha / self.rank) * (
+                self.lora_b.weight @ self.lora_a.weight
+            )
+        # Create the replacement on the same device/dtype as the source so
+        # merging a CUDA or bf16 layer does not silently move/convert it.
+        merged = nn.Linear(
+            self.in_dim,
+            self.out_dim,
+            bias=self.use_bias,
+            device=merged_weight.device,
+            dtype=merged_weight.dtype,
+        )
+        merged.weight.data = merged_weight
+        if self.use_bias:
+            merged.bias.data = self.bias.data.clone()
+        return merged
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.nn.functional.linear(x, self.weight, self.bias)
         lora_out = self.lora_a(self.dropout(x))
         lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
         return out + lora_out
+
+
+def merge_lora_weights(model: nn.Module) -> nn.Module:
+    """Replace all LoRALinear modules in the model with merged nn.Linear modules.
+
+    Walks the module tree and substitutes each LoRALinear with a standard
+    nn.Linear whose weight is W + (alpha/rank) * B @ A. This eliminates
+    LoRA overhead at inference time.
+    """
+    # Snapshot the module list first: replacing children while the
+    # named_modules() generator is live would make it walk into the
+    # freshly inserted replacement modules.
+    for _, module in list(model.named_modules()):
+        for attr_name, child in list(module.named_children()):
+            if isinstance(child, LoRALinear):
+                setattr(module, attr_name, child.merge())
+    return model