
Refactor Wan Model Training & Add Wan-VACE Training Support#352

Open
ninatu wants to merge 9 commits into main from ninatu/wan_training

Conversation


@ninatu ninatu commented Mar 11, 2026

This PR introduces several improvements and fixes to Wan model training, and adds support for training Wan-VACE models.

Key changes include:

  1. Bug fixes:

    • Resolved a training-mode bug when dropout > 0: the rngs parameter is now passed to layer_forward, which is essential for gradient checkpointing with dropout.
    • Fixed prepare_sample_fn usage for 'tfrecord' dataset type.
    • Addressed checkpoint loading issues with larger TPU slices and different topologies for Wan 2.1.
    • Corrected timestep sampling by switching from discrete to continuous sampling.
  2. Config updates:

    • Ensured adam_weight_decay is a float.
    • Added tensorboard_dir parameter for logging.
    • Now uses config.learning_rate instead of a hardcoded value.
    • Set default dropout to 0.0 in WAN configs (instead of 0.1).
  3. Wan-VACE Support:

    • Refactoring: Common training components (initialization, scheduler, TFLOPs calculation, training/eval loops) have been abstracted into a new BaseWanTrainer ABC to improve code structure and reusability.
    • Added new scripts (train_wan_vace.py), trainer (wan_vace_trainer.py), and checkpointing logic (wan_vace_checkpointing_2_1.py) to enable training of WAN-VACE models.
  4. New Features:

    • Introduced config.disable_training_weights to optionally disable mid-point loss weighting.
    • Added logging for max_grad_norm and max_abs_grad.
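The `disable_training_weights` toggle can be sketched as follows. This is a minimal illustration, not the PR's implementation: the bell-shaped weight centred on the mid-point timestep is a hypothetical stand-in for whatever weighting the trainer actually applies.

```python
import math

def loss_weight(t, disable_training_weights=False):
    """Hypothetical mid-point loss weight for a timestep t in [0, 1].

    When disable_training_weights is True, every timestep contributes
    equally (weight 1.0); otherwise timesteps near t = 0.5 are emphasised
    via an illustrative Gaussian bump.
    """
    if disable_training_weights:
        return 1.0
    # Bell-shaped weight peaking at the mid-point t = 0.5.
    return math.exp(-((t - 0.5) ** 2) / (2 * 0.25 ** 2))

print(loss_weight(0.5))                                 # peak of the bump: 1.0
print(loss_weight(0.1, disable_training_weights=True))  # always 1.0 when disabled
```

With the flag set, the weighted loss reduces to a plain mean over timesteps.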

ninatu and others added 9 commits March 11, 2026 14:22
- Ensure `adam_weight_decay` is a float.
- Add `tensorboard_dir` parameter for logging.

Co-authored-by: martinarroyo <martinarroyo@google.com>
- Conditionally apply dropout only when rate > 0.
- Use standard list initialization.
- Add rngs parameter to layer_forward (essential for gradient checkpointing with dropout > 0)

Co-authored-by: martinarroyo <martinarroyo@google.com>
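The failure mode this commit fixes can be shown with a framework-free sketch: if a checkpointing wrapper does not forward the rngs argument, the layer silently runs without dropout. Function names and the dropout stand-in below are illustrative, not the actual Flax code.

```python
def layer_forward(params, x, rngs=None):
    """Toy layer: applies 'dropout' only when a dropout rng is supplied."""
    if rngs is not None and "dropout" in rngs:
        x = x * 0.5  # stand-in for a stochastic dropout mask
    return params * x

def checkpointed_buggy(fn):
    # Bug: the wrapper drops extra keyword arguments, so the dropout rng
    # never reaches the layer and dropout is silently disabled.
    def wrapped(params, x, **kwargs):
        return fn(params, x)
    return wrapped

def checkpointed_fixed(fn):
    # Fix: forward rngs through the wrapper unchanged.
    def wrapped(params, x, rngs=None):
        return fn(params, x, rngs=rngs)
    return wrapped

rngs = {"dropout": 0}
print(checkpointed_buggy(layer_forward)(2.0, 3.0, rngs=rngs))  # 6.0: dropout lost
print(checkpointed_fixed(layer_forward)(2.0, 3.0, rngs=rngs))  # 3.0: dropout applied
```

In Flax terms, the analogous fix is threading the rngs through the rematerialized (gradient-checkpointed) layer call so stochastic layers see their rng streams.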
Replaces the hardcoded learning rate in the optimizer creation with the value from `config.learning_rate`.

Co-authored-by: martinarroyo <martinarroyo@google.com>
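Both config fixes amount to reading values from the config rather than trusting parser types or literals. The config shape below is hypothetical; only the two key names come from the PR.

```python
# YAML/CLI overrides often arrive as strings, so coerce explicitly.
config = {
    "adam_weight_decay": "1e-2",  # may be parsed as a string
    "learning_rate": 1e-4,
}

adam_weight_decay = float(config["adam_weight_decay"])  # ensure a float
learning_rate = config["learning_rate"]  # previously a hardcoded literal

print(adam_weight_decay, learning_rate)
```

The coerced values can then be handed to the optimizer constructor directly.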
Co-authored-by: martinarroyo <martinarroyo@google.com>
Co-authored-by: martinarroyo <martinarroyo@google.com>
…lices and different topologies

Co-authored-by: martinarroyo <martinarroyo@google.com>
…ling and introduce disable_training_weights, add max_grad_norm and max_abs_grad logging.

- Switch timestep sampling from discrete to continuous.
- Add max_grad_norm and max_abs_grad calculation and logging.
- Introduce `config.disable_training_weights` to optionally disable mid-point loss weighting.

Co-authored-by: martinarroyo <martinarroyo@google.com>
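The two sampling modes and the gradient metrics can be sketched without any framework. The metric definitions below (global L2 norm and max absolute entry) are a plausible reading of the logged names, not a quote of the PR's code.

```python
import math
import random

NUM_TRAIN_TIMESTEPS = 1000  # illustrative scheduler length

def sample_timestep_discrete(rng):
    # Old behaviour: integer timesteps only.
    return rng.randrange(NUM_TRAIN_TIMESTEPS)

def sample_timestep_continuous(rng):
    # New behaviour: any real timestep in [0, NUM_TRAIN_TIMESTEPS).
    return rng.random() * NUM_TRAIN_TIMESTEPS

def grad_metrics(grads):
    """Global L2 norm and max absolute entry over a dict of gradient lists.

    Assumed definitions: max_grad_norm as the global norm of all gradients,
    max_abs_grad as the largest absolute gradient entry.
    """
    flat = [g for leaf in grads.values() for g in leaf]
    max_abs_grad = max(abs(g) for g in flat)
    max_grad_norm = math.sqrt(sum(g * g for g in flat))
    return max_grad_norm, max_abs_grad

rng = random.Random(0)
print(sample_timestep_continuous(rng))       # a real-valued timestep
print(grad_metrics({"w": [3.0, -4.0]}))      # (5.0, 4.0)
```

Logging both metrics each step makes gradient spikes visible before they destabilise training.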
The following key functionalities have been moved from WanTrainer to the new `BaseWanTrainer` ABC:
- Initialization and config handling
- Scheduler creation
- TFLOPs calculation
- Core training and evaluation loops (`start_training`, `training_loop`, `eval`)
- Abstract methods for checkpointer, data loading, sharding, and step functions.

Co-authored-by: martinarroyo <martinarroyo@google.com>
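The refactor described above has the classic template-method shape: the base class owns the shared loop, subclasses fill in model-specific pieces. The sketch below mirrors that structure with stdlib `abc`; method bodies and return values are illustrative only.

```python
from abc import ABC, abstractmethod

class BaseWanTrainer(ABC):
    """Sketch of the shared-trainer shape (names follow the PR, bodies are toy)."""

    def __init__(self, config):
        self.config = config  # shared initialization / config handling

    def start_training(self):
        # Shared driver: subclasses supply only model-specific pieces.
        checkpointer = self.create_checkpointer()
        for batch in self.load_data():
            self.train_step(batch)
        return checkpointer  # returned here just so the sketch is testable

    @abstractmethod
    def create_checkpointer(self): ...

    @abstractmethod
    def load_data(self): ...

    @abstractmethod
    def train_step(self, batch): ...

class WanVaceTrainer(BaseWanTrainer):
    # Model-specific overrides; real versions would build VACE-specific
    # checkpointing, data pipelines, and sharded step functions.
    def create_checkpointer(self):
        return "vace-checkpointer"

    def load_data(self):
        return [1, 2, 3]

    def train_step(self, batch):
        pass

print(WanVaceTrainer({}).start_training())
```

Because the abstract methods gate instantiation, a trainer that forgets to implement, say, `train_step` fails loudly at construction time rather than mid-run.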
Introduces training support for WAN-VACE models.

New files:
- train_wan_vace.py: Main training script.
- wan_vace_trainer.py: Trainer class for WAN-VACE.
- wan_vace_checkpointing_2_1.py: Checkpointing logic for WAN-VACE.

Co-authored-by: martinarroyo <martinarroyo@google.com>
@ninatu ninatu requested a review from entrpn as a code owner March 11, 2026 14:35
@github-actions

# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
base_output_directory: ""
tensorboard_dir: ""
Collaborator

tensorboard_dir is created automatically inside the pyconfig. Is there a reason it needs to be in the config?

@entrpn
Collaborator

entrpn commented Mar 12, 2026

As this is a fairly large refactor:

  • @prishajain1 can you do a review of the checkpointing changes?

  • @susanbao can you take a quick look at the training changes?

