Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions litebird_sim/non_linearity.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
from dataclasses import dataclass

import numpy as np
from numpy.random import PCG64, Generator, SeedSequence

from .observations import Observation
from .seeding import regenerate_or_check_detector_generators


@dataclass
Expand Down Expand Up @@ -123,7 +124,6 @@ def apply_quadratic_nonlin_to_observations(
nl_params: NonLinParams | None = None,
component: str = "tod",
user_seed: int | None = None,
dets_random: list[np.random.Generator] | None = None,
):
"""
Apply a quadratic nonlinearity to some time-ordered data
Expand Down Expand Up @@ -155,23 +155,15 @@ def apply_quadratic_nonlin_to_observations(
component : str, optional
Name of the TOD attribute to modify. Defaults to `"tod"`.
user_seed : int, optional
Base seed to build the RNG hierarchy and generate detector-level RNGs
that overwrite any eventual `dets_random`. Required if `dets_random`
is not provided.
dets_random : list of np.random.Generator, optional
List of per-detector random number generators. If not provided, and
`user_seed` is given, generators are created internally. One of
`user_seed` or `dets_random` must be provided.
Base seed to build the RNG that overwrite the default generators
defined from the detector names.

Raises
------
TypeError
If `observations` is neither an `Observation` nor a list of them.
ValueError
If neither `user_seed` nor `dets_random` is provided.
AssertionError
If the number of random generators does not match the number of detectors.
"""

if nl_params is None:
nl_params = NonLinParams()

Expand All @@ -183,14 +175,19 @@ def apply_quadratic_nonlin_to_observations(
raise TypeError(
"The parameter `observations` must be an `Observation` or a list of `Observation`."
)
dets_random = regenerate_or_check_detector_generators(
observations=obs_list,
user_seed=user_seed,
dets_random=dets_random,
)

# iterate through each observation
for cur_obs in obs_list:
det_names = cur_obs.name
if user_seed is None:
seeds = [sum(ord(c) for c in dn) for dn in det_names]
sg = SeedSequence(seeds)
dets_random = [Generator(PCG64(s)) for s in sg.spawn(cur_obs.n_detectors)]

else:
sg = SeedSequence(user_seed)
dets_random = [Generator(PCG64(s)) for s in sg.spawn(cur_obs.n_detectors)]

tod = getattr(cur_obs, component)

apply_quadratic_nonlin(
Expand Down
24 changes: 10 additions & 14 deletions litebird_sim/simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,17 @@
add_convolved_sky_to_observations,
)
from .beam_synthesis import generate_gauss_beam_alms
from .constants import NUMBA_NUM_THREADS_ENVVAR
from .coordinates import CoordinateSystem
from .detectors import DetectorInfo, FreqChannelInfo, InstrumentInfo, UUID
from .detectors import UUID, DetectorInfo, FreqChannelInfo, InstrumentInfo
from .dipole import DipoleType, add_dipole_to_observations
from .distribute import distribute_evenly, distribute_optimally
from .gaindrifts import GainDriftType, GainDriftParams, apply_gaindrift_to_observations
from .gaindrifts import GainDriftParams, GainDriftType, apply_gaindrift_to_observations
from .healpix import write_healpix_map_to_file
from .hwp import HWP
from .hwp_diff_emiss import add_2f_to_observations
from .imo.imo import Imo
from .input_sky import SkyGenerationParams, SkyGenerator
from .io import read_list_of_observations, write_list_of_observations
from .mapmaking import (
BinnerResult,
Expand All @@ -54,8 +56,8 @@
make_destriped_map,
save_destriper_results,
)
from .input_sky import SkyGenerator, SkyGenerationParams
from .mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID
from .maps_and_harmonics import HealpixMap, SphericalHarmonics
from .mpi import MPI_COMM_GRID, MPI_COMM_WORLD, MPI_ENABLED
from .noise import add_noise_to_observations
from .non_linearity import NonLinParams, apply_quadratic_nonlin_to_observations
from .observations import Observation, TodDescription
Expand All @@ -65,13 +67,13 @@
from .scanning import ScanningStrategy, SpinningScanningStrategy
from .seeding import RNGHierarchy
from .spacecraft import SpacecraftOrbit, spacecraft_pos_and_vel
from .maps_and_harmonics import SphericalHarmonics, HealpixMap
from .units import Units
from .version import (
__version__ as litebird_sim_version,
__author__ as litebird_sim_author,
)
from .constants import NUMBA_NUM_THREADS_ENVVAR
from .version import (
__version__ as litebird_sim_version,
)

DEFAULT_BASE_IMO_URL = "https://litebirdimo.ssdc.asi.it"

Expand Down Expand Up @@ -2078,7 +2080,6 @@ def apply_quadratic_nonlin(
user_seed: int | None = None,
component: str = "tod",
append_to_report: bool = False,
rng_hierarchy: RNGHierarchy | None = None,
):
"""A method to apply non-linearity to the observation.

Expand All @@ -2087,11 +2088,7 @@ def apply_quadratic_nonlin(
applies non-linearity to a list of :class:`.Observation` instance. Random number generators are obtained from the detector-level layer. As default it uses
the `dets_random` field of a :class:`.Simulation` object for this.
"""
if rng_hierarchy is None:
rng_hierarchy = self.rng_hierarchy
dets_random = rng_hierarchy.get_detector_level_generators_on_rank(
self.mpi_comm.rank
)

if nl_params is None:
nl_params = NonLinParams()

Expand All @@ -2100,7 +2097,6 @@ def apply_quadratic_nonlin(
nl_params=nl_params,
user_seed=user_seed,
component=component,
dets_random=dets_random,
)

if append_to_report and MPI_COMM_WORLD.rank == 0:
Expand Down
29 changes: 13 additions & 16 deletions test/test_nonlinearity.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from copy import deepcopy
import numpy as np

import litebird_sim as lbs
import numpy as np
from astropy.time import Time
from numpy.random import PCG64, Generator, SeedSequence


def test_add_quadratic_nonlinearity():
Expand All @@ -16,7 +18,6 @@ def test_add_quadratic_nonlinearity():
]

nl_params = lbs.NonLinParams(sampling_gaussian_loc=0.0, sampling_gaussian_scale=0.1)

random_seed = 12345
sim = lbs.Simulation(
base_path="nonlin_example",
Expand All @@ -41,23 +42,18 @@ def test_add_quadratic_nonlinearity():
component="nl_2_self",
)

# Applying non-linearity on the given TOD component of an `Observation` object
RNG_hierarchy = lbs.RNGHierarchy(
random_seed, comm_size=1, num_detectors_per_rank=len(dets)
)
dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(0)
lbs.apply_quadratic_nonlin_to_observations(
observations=sim.observations,
nl_params=nl_params,
component="nl_2_obs",
dets_random=dets_random,
)

# Applying non-linearity on the TOD arrays of the individual detectors.
RNG_hierarchy = lbs.RNGHierarchy(
random_seed, comm_size=1, num_detectors_per_rank=len(dets)
)
dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(0)
seeds = [sum(ord(c) for c in dn) for dn in sim.observations[0].name]
sg = SeedSequence(seeds)
dets_random = [
Generator(PCG64(s)) for s in sg.spawn(sim.observations[0].n_detectors)
]
for idx, tod in enumerate(sim.observations[0].nl_2_det):
lbs.apply_quadratic_nonlin_for_one_detector(
tod_det=tod,
Expand All @@ -75,10 +71,11 @@ def test_add_quadratic_nonlinearity():
)

# Check if non-linearity is applied correctly
RNG_hierarchy = lbs.RNGHierarchy(
random_seed, comm_size=1, num_detectors_per_rank=len(dets)
)
dets_random = RNG_hierarchy.get_detector_level_generators_on_rank(0)
seeds = [sum(ord(c) for c in dn) for dn in sim.observations[0].name]
sg = SeedSequence(seeds)
dets_random = [
Generator(PCG64(s)) for s in sg.spawn(sim.observations[0].n_detectors)
]
sim.observations[0].tod_origin = np.ones_like(sim.observations[0].tod)
for idx, tod in enumerate(sim.observations[0].nl_2_det):
g_one_over_k = dets_random[idx].normal(
Expand Down
Loading