-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgridSearch.py
More file actions
90 lines (82 loc) · 3.45 KB
/
gridSearch.py
File metadata and controls
90 lines (82 loc) · 3.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import json
import os
from random import uniform
import argparse
from train_ae import train_example
from utils.dataset import ShapeNetDataset
from visualization_tools import printPointCloud as ptPC
import torch
import sys
from sklearn.model_selection import ParameterGrid
import numpy as np
import gc
def optimize_params(filepath=os.path.join("parameters", "params.json")):
    """Grid-search the hyperparameters listed in a JSON file.

    Every key in the JSON whose value is a list becomes one axis of the
    search grid; every scalar value is set as a fixed attribute on the
    args namespace (the literal string 'None' is converted to Python None).
    For each grid point, train_example() is run and the configuration with
    the lowest validation loss is kept.

    :param filepath: string: json file path (contains ALL the hyperparameters, also those fixed: see
    lr_params.json for reference).
    N.B.: it should not contain the hyperparameters passed through default_params
    :return: dict {hyperparam: value, ...} with the grid-searched
        hyperparameters of the best run (lowest validation loss).
    """
    # Use a context manager so the file handle is not leaked.
    with open(filepath) as params_file:
        json_params = json.load(params_file)
    parser = argparse.ArgumentParser(description=f'Model validation')
    args = parser.parse_args()
    best_val_loss = sys.float_info.max
    dict_params = {}  # run number -> stringified args; checkpointed to disk after every run
    best_hyperparams = {}  # best grid values found so far (only those grid-searched)
    current_hyperparams = {}  # grid values of the run currently in progress
    param_grid = {}
    for hyperparam, value in json_params.items():
        if isinstance(value, list):
            # A list marks a hyperparameter axis to search over.
            param_grid[hyperparam] = value
        else:
            if value == 'None':
                value = None
            setattr(args, hyperparam, value)
    test_dataset = ShapeNetDataset(
        root=args.dataset,
        split='test',
        class_choice=args.test_class_choice,
        npoints=1024)
    for counter, current_param_grid in enumerate(ParameterGrid(param_grid)):
        for hyperparam, value in current_param_grid.items():
            setattr(args, hyperparam, value)
            current_hyperparams[hyperparam] = value
        setattr(args, 'runNumber', counter)
        print(f"\n\n------------------------------------------------------------------\nParameters: {args}\n")
        # val_losses is the list of losses obtained during validation
        model, val_losses = train_example(args)
        run_best = min(val_losses)  # hoisted: used three times below
        if run_best < best_val_loss:
            print(f"--- Best validation loss found! {run_best} (previous one: {best_val_loss}), corresponding to "
                  f"hyperparameters {current_hyperparams}")
            best_val_loss = run_best
            # BUG FIX: take a snapshot copy. The original assigned the dict by
            # reference, so later iterations mutated "best_hyperparams" too and
            # the function returned the LAST grid point, not the best one.
            best_hyperparams = dict(current_hyperparams)
            ptPC.print_original_decoded_point_clouds(test_dataset, args.test_class_choice, model, args)
        dict_params[counter] = str(args)
        folder = args.outf
        # exist_ok=True tolerates an existing folder without hiding real
        # OSErrors (e.g. permission denied) like the old bare except did.
        os.makedirs(folder, exist_ok=True)
        with open(os.path.join(folder, f'params_dictionary.json'), 'w') as f:
            json.dump(dict_params, f)
        # Release the model and free GPU memory before the next configuration.
        del model
        gc.collect()
        torch.cuda.empty_cache()
    return best_hyperparams
# def wrapper_train_by_class(json_path=os.path.join("parameters", "params.json")):
# json_params = json.loads(open(json_path).read())
# parser = argparse.ArgumentParser(description=f'Training models with different classes')
# args = parser.parse_args()
# for hyperparam, value in json_params.items():
# if value == 'None':
# value = None
# setattr(args, hyperparam, value)
# train_model_by_class(args)
if __name__ == '__main__':
    # Run the grid search with the table-specific parameter file.
    params_path = os.path.join("parameters", "gridsearch_table_parameters.json")
    best_params = optimize_params(filepath=params_path)
    #print(f"Best parameters: \t{best_params}\n")
    # wrapper_train_by_class()