feat(nn): training options and in-memory batching#35
Draft
szvsw wants to merge 4 commits intofeature/create-refactored-surrogate-trainingfrom
Draft
feat(nn): training options and in-memory batching#35szvsw wants to merge 4 commits intofeature/create-refactored-surrogate-trainingfrom
szvsw wants to merge 4 commits intofeature/create-refactored-surrogate-trainingfrom
Conversation
…eshold Require validation loss to improve by at least min_delta vs the previous best before resetting early stopping patience. Default 0 preserves prior behavior (any strictly lower val loss counts). Co-authored-by: Sam Wolk <szvsw@users.noreply.github.com>
Trainer l1_penalty adds lambda * sum(|theta|) to the training MSE loss. L2-style regularization remains optimizer weight_decay. Validation uses plain MSE only. Co-authored-by: Sam Wolk <szvsw@users.noreply.github.com>
Stop when monotonic elapsed time from the first training batch exceeds the limit. Partial epochs are skipped without validation; best checkpoint and post-training flow unchanged. Co-authored-by: Sam Wolk <szvsw@users.noreply.github.com>
Shuffle each epoch with np.random.permutation over row indices; slice batches with drop-last semantics. Validation uses sequential tensor slices. Avoids DataLoader overhead for data already in memory. Co-authored-by: Sam Wolk <szvsw@users.noreply.github.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Extends the PyTorch NN surrogate backend with:
Early stopping improvement threshold —
early_stopping_min_deltaonNNTrainerConfig. Validation loss must improve by at least this amount versus the previous best to reset patience (default0.0: any strictly lower val loss counts).L1 regularization —
l1_penaltyonNNTrainerConfig. When positive, training loss is MSE plusl1_penalty * sum(|theta|)over trainable parameters. Default0.0disables it. Validation and early stopping still use plain MSE.Max training time —
max_training_minutesonNNTrainerConfig(defaultNone). When set, training stops once monotonic elapsed time from the start of the first training batch of the first epoch reaches that many minutes. The current epoch is abandoned without validation if the limit is hit mid-epoch. Usestime.monotonic().In-memory batching — Training no longer uses
DataLoader. Data stays as CPUtorchtensors built from NumPy; each epoch usesnp.random.permutation(n_samples)and batch indices withdrop_last-equivalent batch count (n // batch_size). Validation uses sequential tensor slices with the same drop-last rule and no shuffle.L2: Unchanged — Adam/SGD
weight_decayremains the L2-style term on the optimizer.Example (YAML)
Testing
uv run pytest— 5 passedNotes
-c commit.gpgsign=falsewhen the SSH signing helper hits a GLIBC mismatch in this environment.