Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions examples/models/llama/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,38 @@
self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)

def merge(self) -> nn.Linear:
    """Collapse the LoRA adaptation into a single plain ``nn.Linear``.

    The fused weight is ``W + (alpha / rank) * B @ A``, so the returned
    layer reproduces this module's output exactly while removing the
    low-rank side path (zero extra latency at inference).

    Returns:
        A standard ``nn.Linear`` equivalent to this adapted layer.
    """
    scaling = self.alpha / self.rank
    # Low-rank update: (out_dim x rank) @ (rank x in_dim) -> (out_dim x in_dim).
    delta = self.lora_b.weight @ self.lora_a.weight
    fused = nn.Linear(self.in_dim, self.out_dim, bias=self.use_bias)
    fused.weight.data = self.weight + scaling * delta
    if self.use_bias:
        fused.bias.data = self.bias.data.clone()
    return fused

def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Compute the base projection plus the scaled LoRA correction.

    Returns ``x @ W^T (+ bias)`` augmented with
    ``(alpha / rank) * B(A(dropout(x)))``.
    """
    base = torch.nn.functional.linear(x, self.weight, self.bias)
    adapter = self.lora_b(self.lora_a(self.dropout(x)))
    return base + (self.alpha / self.rank) * adapter


def merge_lora_weights(model: nn.Module) -> nn.Module:
    """Replace all LoRALinear modules in the model with merged nn.Linear modules.

    Walks the module tree and substitutes each LoRALinear with a standard
    nn.Linear whose weight is W + (alpha/rank) * B @ A. This eliminates
    LoRA overhead at inference time.

    Args:
        model: Module tree to rewrite; it is mutated in place.

    Returns:
        The same ``model`` instance, with every LoRALinear replaced.
    """
    # Snapshot the traversal first: named_modules() is a lazy generator and
    # setattr() below mutates the tree while we walk it.
    # B007 fix: the parent's qualified name is unused, so underscore it.
    for _name, module in list(model.named_modules()):
        for attr_name, child in list(module.named_children()):
            if isinstance(child, LoRALinear):
                merged = child.merge()
                setattr(module, attr_name, merged)
    return model
Loading