diff --git a/litebird_sim/non_linearity.py b/litebird_sim/non_linearity.py index dbea2fd1..1636b91d 100644 --- a/litebird_sim/non_linearity.py +++ b/litebird_sim/non_linearity.py @@ -1,8 +1,9 @@ -import numpy as np from dataclasses import dataclass +import numpy as np +from numpy.random import PCG64, Generator, SeedSequence + from .observations import Observation -from .seeding import regenerate_or_check_detector_generators @dataclass @@ -123,7 +124,6 @@ def apply_quadratic_nonlin_to_observations( nl_params: NonLinParams | None = None, component: str = "tod", user_seed: int | None = None, - dets_random: list[np.random.Generator] | None = None, ): """ Apply a quadratic nonlinearity to some time-ordered data @@ -155,23 +155,15 @@ def apply_quadratic_nonlin_to_observations( component : str, optional Name of the TOD attribute to modify. Defaults to `"tod"`. user_seed : int, optional - Base seed to build the RNG hierarchy and generate detector-level RNGs - that overwrite any eventual `dets_random`. Required if `dets_random` - is not provided. - dets_random : list of np.random.Generator, optional - List of per-detector random number generators. If not provided, and - `user_seed` is given, generators are created internally. One of - `user_seed` or `dets_random` must be provided. + Base seed to build the RNG that overwrite the default generators + defined from the detector names. Raises ------ TypeError If `observations` is neither an `Observation` nor a list of them. - ValueError - If neither `user_seed` nor `dets_random` is provided. - AssertionError - If the number of random generators does not match the number of detectors. """ + if nl_params is None: nl_params = NonLinParams() @@ -183,14 +175,19 @@ def apply_quadratic_nonlin_to_observations( raise TypeError( "The parameter `observations` must be an `Observation` or a list of `Observation`." ) - dets_random = regenerate_or_check_detector_generators( - observations=obs_list, - user_seed=user_seed, - dets_random=dets_random, - ) # iterate through each observation for cur_obs in obs_list: + det_names = cur_obs.name + if user_seed is None: + seeds = [sum(ord(c) for c in dn) for dn in det_names] + sg = SeedSequence(seeds) + dets_random = [Generator(PCG64(s)) for s in sg.spawn(cur_obs.n_detectors)] + + else: + sg = SeedSequence(user_seed) + dets_random = [Generator(PCG64(s)) for s in sg.spawn(cur_obs.n_detectors)] + tod = getattr(cur_obs, component) apply_quadratic_nonlin( diff --git a/litebird_sim/simulations.py b/litebird_sim/simulations.py index df2a3cb1..974541a2 100644 --- a/litebird_sim/simulations.py +++ b/litebird_sim/simulations.py @@ -33,15 +33,17 @@ add_convolved_sky_to_observations, ) from .beam_synthesis import generate_gauss_beam_alms +from .constants import NUMBA_NUM_THREADS_ENVVAR from .coordinates import CoordinateSystem -from .detectors import DetectorInfo, FreqChannelInfo, InstrumentInfo, UUID +from .detectors import UUID, DetectorInfo, FreqChannelInfo, InstrumentInfo from .dipole import DipoleType, add_dipole_to_observations from .distribute import distribute_evenly, distribute_optimally -from .gaindrifts import GainDriftType, GainDriftParams, apply_gaindrift_to_observations +from .gaindrifts import GainDriftParams, GainDriftType, apply_gaindrift_to_observations from .healpix import write_healpix_map_to_file from .hwp import HWP from .hwp_diff_emiss import add_2f_to_observations from .imo.imo import Imo +from .input_sky import SkyGenerationParams, SkyGenerator from .io import read_list_of_observations, write_list_of_observations from .mapmaking import ( BinnerResult, @@ -54,8 +56,8 @@ make_destriped_map, save_destriper_results, ) -from .input_sky import SkyGenerator, SkyGenerationParams -from .mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID +from .maps_and_harmonics import HealpixMap, SphericalHarmonics +from .mpi import MPI_COMM_GRID, MPI_COMM_WORLD, MPI_ENABLED from .noise import add_noise_to_observations from .non_linearity import NonLinParams, apply_quadratic_nonlin_to_observations from .observations import Observation, TodDescription @@ -65,13 +67,13 @@ from .scanning import ScanningStrategy, SpinningScanningStrategy from .seeding import RNGHierarchy from .spacecraft import SpacecraftOrbit, spacecraft_pos_and_vel -from .maps_and_harmonics import SphericalHarmonics, HealpixMap from .units import Units from .version import ( - __version__ as litebird_sim_version, __author__ as litebird_sim_author, ) -from .constants import NUMBA_NUM_THREADS_ENVVAR +from .version import ( + __version__ as litebird_sim_version, +) DEFAULT_BASE_IMO_URL = "https://litebirdimo.ssdc.asi.it" @@ -2078,7 +2080,6 @@ def apply_quadratic_nonlin( user_seed: int | None = None, component: str = "tod", append_to_report: bool = False, - rng_hierarchy: RNGHierarchy | None = None, ): """A method to apply non-linearity to the observation. @@ -2087,11 +2088,7 @@ def apply_quadratic_nonlin( applies non-linearity to a list of :class:`.Observation` instance. Random number generators are obtained from the detector-level layer. As default it uses the `dets_random` field of a :class:`.Simulation` object for this. """ - if rng_hierarchy is None: - rng_hierarchy = self.rng_hierarchy - dets_random = rng_hierarchy.get_detector_level_generators_on_rank( - self.mpi_comm.rank - ) + if nl_params is None: nl_params = NonLinParams() @@ -2100,7 +2097,6 @@ def apply_quadratic_nonlin( nl_params=nl_params, user_seed=user_seed, component=component, - dets_random=dets_random, ) if append_to_report and MPI_COMM_WORLD.rank == 0: diff --git a/test/test_nonlinearity.py b/test/test_nonlinearity.py index 86f4c226..aa050690 100644 --- a/test/test_nonlinearity.py +++ b/test/test_nonlinearity.py @@ -1,7 +1,9 @@ from copy import deepcopy -import numpy as np + import litebird_sim as lbs +import numpy as np from astropy.time import Time +from numpy.random import PCG64, Generator, SeedSequence def test_add_quadratic_nonlinearity(): @@ -16,7 +18,6 @@ def test_add_quadratic_nonlinearity(): ] nl_params = lbs.NonLinParams(sampling_gaussian_loc=0.0, sampling_gaussian_scale=0.1) - random_seed = 12345 sim = lbs.Simulation( base_path="nonlin_example", @@ -41,23 +42,18 @@ def test_add_quadratic_nonlinearity(): component="nl_2_self", ) - # Applying non-linearity on the given TOD component of an `Observation` object - RNG_hierarchy = lbs.RNGHierarchy( - random_seed, comm_size=1, num_detectors_per_rank=len(dets) - ) - dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(0) lbs.apply_quadratic_nonlin_to_observations( observations=sim.observations, nl_params=nl_params, component="nl_2_obs", - dets_random=dets_random, ) # Applying non-linearity on the TOD arrays of the individual detectors. - RNG_hierarchy = lbs.RNGHierarchy( - random_seed, comm_size=1, num_detectors_per_rank=len(dets) - ) - dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(0) + seeds = [sum(ord(c) for c in dn) for dn in sim.observations[0].name] + sg = SeedSequence(seeds) + dets_random = [ + Generator(PCG64(s)) for s in sg.spawn(sim.observations[0].n_detectors) + ] for idx, tod in enumerate(sim.observations[0].nl_2_det): lbs.apply_quadratic_nonlin_for_one_detector( tod_det=tod, @@ -75,10 +71,11 @@ def test_add_quadratic_nonlinearity(): ) # Check if non-linearity is applied correctly - RNG_hierarchy = lbs.RNGHierarchy( - random_seed, comm_size=1, num_detectors_per_rank=len(dets) - ) - dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(0) + seeds = [sum(ord(c) for c in dn) for dn in sim.observations[0].name] + sg = SeedSequence(seeds) + dets_random = [ + Generator(PCG64(s)) for s in sg.spawn(sim.observations[0].n_detectors) + ] sim.observations[0].tod_origin = np.ones_like(sim.observations[0].tod) for idx, tod in enumerate(sim.observations[0].nl_2_det): g_one_over_k = dets_random[idx].normal(