diff --git a/src/pipt/loop/assimilation.py b/src/pipt/loop/assimilation.py index 76f26fb..d280547 100644 --- a/src/pipt/loop/assimilation.py +++ b/src/pipt/loop/assimilation.py @@ -475,13 +475,10 @@ def post_process_forecast(self): vintage = 0 # Store according to sparse_info - if vintage < len(self.ensemble.sparse_info['mask']) and \ - pred_data[key].shape[0] == int(np.sum(self.ensemble.sparse_info['mask'][vintage])): - + if key == self.ensemble.sparse_info['compress_data'] and pred_data[key] is not None: # If first entry in pred_data_tmp if pred_data_tmp[i] is None: pred_data_tmp[i] = {key: pred_data[key]} - else: pred_data_tmp[i][key] = pred_data[key] @@ -516,19 +513,18 @@ def post_process_forecast(self): self.ensemble.data_rec = [] for i in range(len(pred_data_tmp)): # INDEX if pred_data_tmp[i] is not None: - for k in pred_data_tmp[i]: # DATATYPE - if vintage < len(self.ensemble.sparse_info['mask']) and \ - len(pred_data_tmp[i][k]) == int(np.sum(self.ensemble.sparse_info['mask'][vintage])): + for key in pred_data_tmp[i]: # DATATYPE + if key == self.ensemble.sparse_info['compress_data']: if self.ensemble.keys_da['daalg'][1] == 'gies': - self.ensemble.pred_data[i][k] = np.zeros( - (len(self.ensemble.obs_data[i][k]), self.ensemble.ne+1)) + self.ensemble.pred_data[i][key] = np.zeros( + (len(self.ensemble.obs_data[i][key]), self.ensemble.ne+1)) else: - self.ensemble.pred_data[i][k] = np.zeros( - (len(self.ensemble.obs_data[i][k]), self.ensemble.ne)) - for m in range(pred_data_tmp[i][k].shape[1]): - data_array = self.ensemble.compress_manager(pred_data_tmp[i][k][:, m], vintage, + self.ensemble.pred_data[i][key] = np.zeros( + (len(self.ensemble.obs_data[i][key]), self.ensemble.ne)) + for m in range(pred_data_tmp[i][key].shape[1]): + data_array = self.ensemble.compress_manager(pred_data_tmp[i][key][:, m], vintage, self.ensemble.sparse_info['use_ensemble']) - self.ensemble.pred_data[i][k][:, m] = data_array + self.ensemble.pred_data[i][key][:, m] = data_array vintage = vintage + 1 if self.ensemble.sparse_info['use_ensemble']: self.ensemble.compress_manager() diff --git a/src/pipt/loop/ensemble.py b/src/pipt/loop/ensemble.py index 2e6b53e..1677568 100644 --- a/src/pipt/loop/ensemble.py +++ b/src/pipt/loop/ensemble.py @@ -278,10 +278,8 @@ def _org_obs_data(self): load_data = np.load(truedata[i][0]) # Load the .npz file data_array = load_data[load_data.files[0]] - # Perform compression if required (we only and always compress signals with same size as number of active cells) - if self.sparse_info is not None and \ - vintage < len(self.sparse_info['mask']) and \ - len(data_array) == int(np.sum(self.sparse_info['mask'][vintage])): + # Perform compression for the data type specified in self.sparse_info['compress_data'] if required + if self.sparse_info is not None and self.keys_da['datatype'][0] == self.sparse_info['compress_data']: data_array = self.compress_manager(data_array, vintage, False) vintage = vintage + 1 @@ -306,16 +304,14 @@ def _org_obs_data(self): self.obs_data[i][self.keys_da['datatype'][0]] = np.array( truedata[i][:]) # no need to make this into a list else: - for j in range(len(self.keys_da['datatype'])): # DATATYPE + for j, datatype in enumerate(self.keys_da['datatype']): # Load a Numpy npz file if isinstance(truedata[i][j], str) and truedata[i][j].endswith('.npz'): load_data = np.load(truedata[i][j]) # Load the .npz file data_array = load_data[load_data.files[0]] - # Perform compression if required (we only and always compress signals with same size as number of active cells) - if self.sparse_info is not None and \ - vintage < len(self.sparse_info['mask']) and \ - len(data_array) == int(np.sum(self.sparse_info['mask'][vintage])): + # Perform compression for the data type specified in self.sparse_info['compress_data'] if required + if self.sparse_info is not None and datatype == self.sparse_info['compress_data']: data_array = self.compress_manager(data_array, vintage, False) vintage = vintage + 1 @@ -521,8 +517,7 @@ def _org_data_var(self): # Handle case when noise is estimated using wavelets if self.sparse_info is not None and self.datavar[i][datatype[j]] is not None and \ - vintage < len(self.sparse_info['mask']) and \ - len(self.datavar[i][datatype[j]]) == int(np.sum(self.sparse_info['mask'][vintage])): + datatype[j]==self.sparse_info['compress_data']: # compute var from sparse_data est_noise = np.power(self.sparse_data[vintage].est_noise, 2) self.datavar[i][datatype[j]] = est_noise # override the given value diff --git a/src/pipt/misc_tools/extract_tools.py b/src/pipt/misc_tools/extract_tools.py index 4cb2326..e62a0f9 100644 --- a/src/pipt/misc_tools/extract_tools.py +++ b/src/pipt/misc_tools/extract_tools.py @@ -444,6 +444,7 @@ def organize_sparse_representation(info: Union[dict,list]) -> dict: sparse['mask'].append(mask.flatten()) # Read rest of keywords + sparse['compress_data'] = info.get('compress_data', False) sparse['level'] = info['level'] sparse['wname'] = info['wname'] sparse['threshold_rule'] = info['threshold_rule']