Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 10 additions & 14 deletions src/pipt/loop/assimilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,13 +475,10 @@ def post_process_forecast(self):
vintage = 0

# Store according to sparse_info
if vintage < len(self.ensemble.sparse_info['mask']) and \
pred_data[key].shape[0] == int(np.sum(self.ensemble.sparse_info['mask'][vintage])):

if key == self.ensemble.sparse_info['compress_data'] and pred_data[key] is not None:
# If first entry in pred_data_tmp
if pred_data_tmp[i] is None:
pred_data_tmp[i] = {key: pred_data[key]}

else:
pred_data_tmp[i][key] = pred_data[key]

Expand Down Expand Up @@ -516,19 +513,18 @@ def post_process_forecast(self):
self.ensemble.data_rec = []
for i in range(len(pred_data_tmp)): # INDEX
if pred_data_tmp[i] is not None:
for k in pred_data_tmp[i]: # DATATYPE
if vintage < len(self.ensemble.sparse_info['mask']) and \
len(pred_data_tmp[i][k]) == int(np.sum(self.ensemble.sparse_info['mask'][vintage])):
for key in pred_data_tmp[i]: # DATATYPE
if key == self.ensemble.sparse_info['compress_data']:
if self.ensemble.keys_da['daalg'][1] == 'gies':
self.ensemble.pred_data[i][k] = np.zeros(
(len(self.ensemble.obs_data[i][k]), self.ensemble.ne+1))
self.ensemble.pred_data[i][key] = np.zeros(
(len(self.ensemble.obs_data[i][key]), self.ensemble.ne+1))
else:
self.ensemble.pred_data[i][k] = np.zeros(
(len(self.ensemble.obs_data[i][k]), self.ensemble.ne))
for m in range(pred_data_tmp[i][k].shape[1]):
data_array = self.ensemble.compress_manager(pred_data_tmp[i][k][:, m], vintage,
self.ensemble.pred_data[i][key] = np.zeros(
(len(self.ensemble.obs_data[i][key]), self.ensemble.ne))
for m in range(pred_data_tmp[i][key].shape[1]):
data_array = self.ensemble.compress_manager(pred_data_tmp[i][key][:, m], vintage,
self.ensemble.sparse_info['use_ensemble'])
self.ensemble.pred_data[i][k][:, m] = data_array
self.ensemble.pred_data[i][key][:, m] = data_array
vintage = vintage + 1
if self.ensemble.sparse_info['use_ensemble']:
self.ensemble.compress_manager()
Expand Down
17 changes: 6 additions & 11 deletions src/pipt/loop/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,8 @@ def _org_obs_data(self):
load_data = np.load(truedata[i][0]) # Load the .npz file
data_array = load_data[load_data.files[0]]

# Perform compression if required (we only and always compress signals with same size as number of active cells)
if self.sparse_info is not None and \
vintage < len(self.sparse_info['mask']) and \
len(data_array) == int(np.sum(self.sparse_info['mask'][vintage])):
# Perform compression for the data type specified in self.sparse_info['compress_data'] if required
if self.sparse_info is not None and self.keys_da['datatype'][0] == self.sparse_info['compress_data']:
data_array = self.compress_manager(data_array, vintage, False)
vintage = vintage + 1

Expand All @@ -306,16 +304,14 @@ def _org_obs_data(self):
self.obs_data[i][self.keys_da['datatype'][0]] = np.array(
truedata[i][:]) # no need to make this into a list
else:
for j in range(len(self.keys_da['datatype'])): # DATATYPE
for j, datatype in enumerate(self.keys_da['datatype']):
# Load a Numpy npz file
if isinstance(truedata[i][j], str) and truedata[i][j].endswith('.npz'):
load_data = np.load(truedata[i][j]) # Load the .npz file
data_array = load_data[load_data.files[0]]

# Perform compression if required (we only and always compress signals with same size as number of active cells)
if self.sparse_info is not None and \
vintage < len(self.sparse_info['mask']) and \
len(data_array) == int(np.sum(self.sparse_info['mask'][vintage])):
# Perform compression for the data type specified in self.sparse_info['compress_data'] if required
if self.sparse_info is not None and datatype == self.sparse_info['compress_data']:
data_array = self.compress_manager(data_array, vintage, False)
vintage = vintage + 1

Expand Down Expand Up @@ -521,8 +517,7 @@ def _org_data_var(self):

# Handle case when noise is estimated using wavelets
if self.sparse_info is not None and self.datavar[i][datatype[j]] is not None and \
vintage < len(self.sparse_info['mask']) and \
len(self.datavar[i][datatype[j]]) == int(np.sum(self.sparse_info['mask'][vintage])):
datatype[j]==self.sparse_info['compress_data']:
# compute var from sparse_data
est_noise = np.power(self.sparse_data[vintage].est_noise, 2)
self.datavar[i][datatype[j]] = est_noise # override the given value
Expand Down
1 change: 1 addition & 0 deletions src/pipt/misc_tools/extract_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ def organize_sparse_representation(info: Union[dict,list]) -> dict:
sparse['mask'].append(mask.flatten())

# Read rest of keywords
sparse['compress_data'] = info.get('compress_data', False)
sparse['level'] = info['level']
sparse['wname'] = info['wname']
sparse['threshold_rule'] = info['threshold_rule']
Expand Down
Loading