-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain_model.py
More file actions
106 lines (90 loc) · 2.79 KB
/
train_model.py
File metadata and controls
106 lines (90 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Script for training a model with or without differential privacy"""
import argparse
from pathlib import Path
import pytorch_lightning as pl
from dpsnn import SplitNN
def _get_model_savepath(root, args):
    """Build the checkpoint path template for a training run.

    The file name is ``{saveas}_{noise_scale}noise_{nopeek_weight}nopeek``
    with every "." removed, prefixed with ``attack_validation_`` when
    ``args.attack_val`` is truthy, and followed by the literal placeholder
    ``_{epoch:02d}``. The placeholder is not formatted here; it is returned
    as plain text (presumably filled in by the checkpointing code).

    Args:
        root: Project root directory. May be a ``pathlib.Path``, a ``str``
            or any other path-like object.
        args: Object exposing ``attack_val``, ``saveas``, ``noise_scale``
            and ``nopeek_weight`` attributes (e.g. an argparse namespace).

    Returns:
        A ``Path`` equal to ``root / "models" / "classifiers" / <file name>``.
    """
    prefix = "attack_validation_" if args.attack_val else ""

    # Dots are stripped from the whole run description, so 0.5 becomes "05"
    # and any "." inside ``saveas`` disappears as well.
    run_name = (
        f"{args.saveas}_{args.noise_scale}noise_{args.nopeek_weight}nopeek"
    ).replace(".", "")

    savename = prefix + run_name + "_{epoch:02d}"

    # Coerce the root so that a plain string works with the "/" operator.
    return Path(root) / "models" / "classifiers" / savename
def main(root, args):
    """Train the module-level ``model`` and checkpoint it under ``root``.

    Note that ``model`` is not a parameter: it is looked up at module scope,
    where the ``__main__`` section of this script assigns it before calling
    this function.

    Args:
        root: Project root directory, forwarded to ``_get_model_savepath``.
        args: Parsed command-line namespace. ``max_epochs`` and ``gpus`` are
            read here; the remaining attributes are used to build the
            checkpoint path.
    """
    savepath = _get_model_savepath(root, args)
    print(f"Saving model to {savepath}")

    # Checkpointing monitors "val_accuracy" and treats larger values as better.
    best_checkpoint = pl.callbacks.ModelCheckpoint(
        filepath=savepath, monitor="val_accuracy", mode="max"
    )

    trainer_options = {
        "max_epochs": args.max_epochs,
        "gpus": args.gpus,
        "checkpoint_callback": best_checkpoint,
    }
    pl.Trainer(**trainer_options).fit(model)
if __name__ == "__main__":
    # ----- Command-line interface -----
    # The parsed namespace is used in two places: this file reads the
    # options that shape the checkpoint name and the Trainer, and the whole
    # namespace is also handed to SplitNN (which presumably consumes the
    # remaining options such as batch size and learning rate).
    parser = argparse.ArgumentParser(
        description="Train a SplitNN with differential privacy optionally applied to intermediate data"
    )
    parser.add_argument(
        "--noise-scale",
        type=float,
        required=True,
        help="Scale of laplacian noise from which to draw. If 0.0, no noise is added. Required.",
    )
    parser.add_argument(
        "--nopeek-weight",
        type=float,
        required=True,
        help="Weighting of nopeek loss term. Required.",
    )
    parser.add_argument(
        "--batch-size", default=128, type=int, help="Batch size (default 128)"
    )
    parser.add_argument(
        "--learning-rate",
        default=1e-4,
        type=float,
        help="Starting learning rate (default 1e-4)",
    )
    parser.add_argument(
        "--saveas",
        default="mnist",
        type=str,
        # Adjacent literals are concatenated verbatim, so the first one must
        # end with a space to keep the two sentences apart in --help output.
        help="Name of model to save as (default is 'mnist'). "
        "Note that '_{noisescale}noise_{nopeekweight}nopeek' will be appended to the end of the name",
    )
    parser.add_argument(
        "--overfit-pct",
        default=0.0,
        type=float,
        help="Proportion of training data to use (default 0.0 [all data])",
    )
    parser.add_argument(
        "--max-epochs",
        type=int,
        default=10,
        help="Number of epochs to train for (default = 10)",
    )
    # No ``type=`` is given, so a value supplied on the command line reaches
    # the Trainer as a string; only the default is a real ``None``.
    parser.add_argument(
        "--gpus", default=None, help="Number of gpus to use (default None)"
    )
    parser.add_argument(
        "--attack-val",
        help="Provide this flag if training a model to validate attacker performance",
        action="store_true",
        dest="attack_val",
    )
    parser.set_defaults(attack_val=False)

    args = parser.parse_args()

    # File paths: the project root is the directory one level above the
    # directory that contains this script.
    project_root = Path(__file__).resolve().parents[1]

    # ----- Model -----
    # Kept at module scope on purpose: main() reads ``model`` as a global.
    model = SplitNN(args)

    # ----- Train model -----
    main(project_root, args)