
Fix missing activation checkpointing (recompute) parameters in bridge mode #1833

Open
XJL010622 wants to merge 2 commits into THUDM:main from XJL010622:fix-bridge-recompute

@XJL010622

Motivation

When megatron_to_hf_mode == "bridge", AutoBridge.from_hf_pretrained() builds the model provider solely from the HuggingFace config.json. HF configs describe only the static model architecture; they do not carry training-specific memory-optimization arguments such as activation checkpointing (recompute).

Consequently, arguments such as recompute_granularity are dropped during provider initialization. Activation checkpointing is then silently disabled even though it was requested on the command line, which leads to unexpected and severe OOM (out-of-memory) errors during training, especially with large models or long context windows.

Modifications

This PR explicitly synchronizes the recompute-related parameters from the command-line args to the provider before provider.finalize() is called.

The loop copies a field only when args has that attribute and its value is not None. Launch scripts that omit some of the recompute arguments therefore keep working, and the provider's existing values are left untouched for those fields.

Changed code snippet (for review)

In get_model_provider_func (inside the bridge conditional branch):

        provider.variable_seq_lengths = args.variable_seq_lengths
        if hasattr(args, "moe_token_dispatcher_type"):
            provider.moe_token_dispatcher_type = args.moe_token_dispatcher_type

        # --- NEW CODE ADDED HERE ---
        # Explicitly sync activation checkpointing parameters since HF config does not contain them
        recompute_fields = (
            "recompute_granularity",
            "recompute_method",
            "recompute_num_layers"
        )
        for field in recompute_fields:
            if hasattr(args, field) and getattr(args, field) is not None:
                setattr(provider, field, getattr(args, field))
        # ---------------------------

        if getattr(args, "decoder_first_pipeline_num_layers", None) is not None:
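
For reviewers who want to check the sync logic without a Megatron environment, here is a minimal standalone sketch. It is not the code in this PR: sync_recompute_fields is a hypothetical helper name, and both args and provider are SimpleNamespace stand-ins for the parsed launch arguments and the bridge-generated provider. The values "full" and "uniform" are only example settings.

```python
from types import SimpleNamespace

RECOMPUTE_FIELDS = (
    "recompute_granularity",
    "recompute_method",
    "recompute_num_layers",
)


def sync_recompute_fields(args, provider, fields=RECOMPUTE_FIELDS):
    """Copy recompute settings from args to provider (hypothetical helper).

    A field is copied only if args defines it and its value is not None.
    Returns the names of the fields that were copied.
    """
    copied = []
    for field in fields:
        # getattr with a default covers both "attribute missing" and "set to None".
        value = getattr(args, field, None)
        if value is not None:
            setattr(provider, field, value)
            copied.append(field)
    return copied


# Stand-in for a provider built from config.json only: no recompute settings.
provider = SimpleNamespace(
    recompute_granularity=None,
    recompute_method=None,
    recompute_num_layers=None,
)
# Stand-in for parsed launch arguments; recompute_num_layers is deliberately absent.
args = SimpleNamespace(recompute_granularity="full", recompute_method="uniform")

copied = sync_recompute_fields(args, provider)
print(copied)
print(vars(provider))
```

Using getattr(args, field, None) collapses the hasattr check and the None check from the snippet above into one lookup; the behavior is intended to be the same.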
