diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index f36b59a0..1c077b56 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -43,7 +43,7 @@ def main(): sys.stdout = sys.stderr from policyengine_us_data.calibration.publish_local_area import ( - build_h5, + build_output_dataset, NYC_COUNTIES, NYC_CDS, AT_LARGE_DISTRICTS, @@ -104,11 +104,11 @@ def main(): continue states_dir = output_dir / "states" states_dir.mkdir(parents=True, exist_ok=True) - path = build_h5( + result = build_output_dataset( weights=weights, geography=geography, dataset_path=dataset_path, - output_path=states_dir / f"{item_id}.h5", + output_base=states_dir / item_id, cd_subset=cd_subset, takeup_filter=takeup_filter, ) @@ -147,11 +147,11 @@ def main(): districts_dir = output_dir / "districts" districts_dir.mkdir(parents=True, exist_ok=True) - path = build_h5( + result = build_output_dataset( weights=weights, geography=geography, dataset_path=dataset_path, - output_path=districts_dir / f"{friendly_name}.h5", + output_base=districts_dir / friendly_name, cd_subset=[geoid], takeup_filter=takeup_filter, ) @@ -166,11 +166,11 @@ def main(): continue cities_dir = output_dir / "cities" cities_dir.mkdir(parents=True, exist_ok=True) - path = build_h5( + result = build_output_dataset( weights=weights, geography=geography, dataset_path=dataset_path, - output_path=cities_dir / "NYC.h5", + output_base=cities_dir / "NYC", cd_subset=cd_subset, county_filter=NYC_COUNTIES, takeup_filter=takeup_filter, @@ -179,16 +179,16 @@ def main(): elif item_type == "national": national_dir = output_dir / "national" national_dir.mkdir(parents=True, exist_ok=True) - path = build_h5( + result = build_output_dataset( weights=weights, geography=geography, dataset_path=dataset_path, - output_path=national_dir / "US.h5", + output_base=national_dir / "US", ) else: raise ValueError(f"Unknown item type: {item_type}") - if path: + if result: results["completed"].append(f"{item_type}:{item_id}") 
print( f"Completed {item_type}:{item_id}", diff --git a/policyengine_us_data/calibration/publish_local_area.py b/policyengine_us_data/calibration/publish_local_area.py index 72594631..0b47c4a9 100644 --- a/policyengine_us_data/calibration/publish_local_area.py +++ b/policyengine_us_data/calibration/publish_local_area.py @@ -8,11 +8,17 @@ python publish_local_area.py [--skip-download] [--states-only] [--upload] """ +import os import numpy as np from pathlib import Path from typing import List from policyengine_us import Microsimulation +from policyengine_us_data.utils.hdfstore import ( + DatasetResult, + save_h5, + save_hdfstore, +) from policyengine_us_data.utils.huggingface import download_calibration_inputs from policyengine_us_data.utils.data_upload import ( upload_local_area_file, @@ -105,39 +111,39 @@ def record_completed_city(city_name: str): f.write(f"{city_name}\n") -def build_h5( +def build_output_dataset( weights: np.ndarray, geography, dataset_path: Path, - output_path: Path, + output_base: Path, cd_subset: List[str] = None, county_filter: set = None, takeup_filter: List[str] = None, -) -> Path: - """Build an H5 file by cloning records for each nonzero weight. +) -> DatasetResult: + """Assemble a dataset and serialize to h5py + HDFStore. Args: weights: Clone-level weight vector, shape (n_clones_total * n_hh,). geography: GeographyAssignment from assign_random_geography. dataset_path: Path to base dataset H5 file. - output_path: Where to write the output H5 file. + output_base: Path stem **without** file extension. + Serializers append ``.h5`` and ``.hdfstore.h5``. cd_subset: If provided, only include clones for these CDs. county_filter: If provided, scale weights by P(target|CD) for city datasets. takeup_filter: List of takeup vars to apply. Returns: - Path to the output H5 file. + A :class:`DatasetResult` with the assembled data. 
""" - import h5py from collections import defaultdict from policyengine_core.enums import Enum from policyengine_us.variables.household.demographic.geographic.county.county_enum import ( County, ) - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) + output_base = Path(output_base) + output_base.parent.mkdir(parents=True, exist_ok=True) blocks = np.asarray(geography.block_geoid) clone_cds = np.asarray(geography.cd_geoid, dtype=str) @@ -182,7 +188,7 @@ def build_h5( else f"{n_clones_total} clone rows" ) print(f"\n{'=' * 60}") - print(f"Building {output_path.name} ({label}, {n_hh} households)") + print(f"Building {output_base.name} ({label}, {n_hh} households)") print(f"{'=' * 60}") # === Identify active clones === @@ -539,31 +545,16 @@ def build_h5( for var_name, bools in takeup_results.items(): data[var_name] = {time_period: bools} - # === Write H5 === - with h5py.File(str(output_path), "w") as f: - for variable, periods in data.items(): - grp = f.create_group(variable) - for period, values in periods.items(): - grp.create_dataset(str(period), data=values) - - print(f"\nH5 saved to {output_path}") - - with h5py.File(str(output_path), "r") as f: - tp = str(time_period) - if "household_id" in f and tp in f["household_id"]: - n = len(f["household_id"][tp][:]) - print(f"Verified: {n:,} households in output") - if "person_id" in f and tp in f["person_id"]: - n = len(f["person_id"][tp][:]) - print(f"Verified: {n:,} persons in output") - if "household_weight" in f and tp in f["household_weight"]: - hw = f["household_weight"][tp][:] - print(f"Total population (HH weights): {hw.sum():,.0f}") - if "person_weight" in f and tp in f["person_weight"]: - pw = f["person_weight"][tp][:] - print(f"Total population (person weights): {pw.sum():,.0f}") - - return output_path + # === Serialize === + result = DatasetResult( + data=data, + time_period=time_period, + system=sim.tax_benefit_system, + ) + save_h5(result, str(output_base)) + 
save_hdfstore(result, str(output_base)) + + return result AT_LARGE_DISTRICTS = {0, 98} @@ -613,22 +604,35 @@ def build_states( print(f"No CDs found for {state_code}, skipping") continue - output_path = states_dir / f"{state_code}.h5" + output_base = states_dir / state_code try: - build_h5( + build_output_dataset( weights=w, geography=geography, dataset_path=dataset_path, - output_path=output_path, + output_base=output_base, cd_subset=cd_subset, takeup_filter=takeup_filter, ) + h5_path = str(output_base) + ".h5" + hdfstore_path = str(output_base) + ".hdfstore.h5" + has_hdfstore = os.path.exists(hdfstore_path) + if upload: print(f"Uploading {state_code}.h5 to GCP...") - upload_local_area_file(str(output_path), "states", skip_hf=True) - hf_queue.append((str(output_path), "states")) + upload_local_area_file(h5_path, "states", skip_hf=True) + + if has_hdfstore: + print(f"Uploading {state_code}.hdfstore.h5 to GCP...") + upload_local_area_file( + hdfstore_path, "states_hdfstore", skip_hf=True + ) + + hf_queue.append((h5_path, "states")) + if has_hdfstore: + hf_queue.append((hdfstore_path, "states_hdfstore")) record_completed_state(state_code) print(f"Completed {state_code}") @@ -680,23 +684,36 @@ def build_districts( print(f"Skipping {friendly_name} (already completed)") continue - output_path = districts_dir / f"{friendly_name}.h5" + output_base = districts_dir / friendly_name print(f"\n[{i + 1}/{len(all_cds)}] Building {friendly_name}") try: - build_h5( + build_output_dataset( weights=w, geography=geography, dataset_path=dataset_path, - output_path=output_path, + output_base=output_base, cd_subset=[cd_geoid], takeup_filter=takeup_filter, ) + h5_path = str(output_base) + ".h5" + hdfstore_path = str(output_base) + ".hdfstore.h5" + has_hdfstore = os.path.exists(hdfstore_path) + if upload: print(f"Uploading {friendly_name}.h5 to GCP...") - upload_local_area_file(str(output_path), "districts", skip_hf=True) - hf_queue.append((str(output_path), "districts")) + 
upload_local_area_file(h5_path, "districts", skip_hf=True) + + if has_hdfstore: + print(f"Uploading {friendly_name}.hdfstore.h5 to GCP...") + upload_local_area_file( + hdfstore_path, "districts_hdfstore", skip_hf=True + ) + + hf_queue.append((h5_path, "districts")) + if has_hdfstore: + hf_queue.append((hdfstore_path, "districts_hdfstore")) record_completed_district(friendly_name) print(f"Completed {friendly_name}") @@ -743,23 +760,36 @@ def build_cities( if not cd_subset: print("No NYC-related CDs found, skipping") else: - output_path = cities_dir / "NYC.h5" + output_base = cities_dir / "NYC" try: - build_h5( + build_output_dataset( weights=w, geography=geography, dataset_path=dataset_path, - output_path=output_path, + output_base=output_base, cd_subset=cd_subset, county_filter=NYC_COUNTIES, takeup_filter=takeup_filter, ) + h5_path = str(output_base) + ".h5" + hdfstore_path = str(output_base) + ".hdfstore.h5" + has_hdfstore = os.path.exists(hdfstore_path) + if upload: print("Uploading NYC.h5 to GCP...") - upload_local_area_file(str(output_path), "cities", skip_hf=True) - hf_queue.append((str(output_path), "cities")) + upload_local_area_file(h5_path, "cities", skip_hf=True) + + if has_hdfstore: + print("Uploading NYC.hdfstore.h5 to GCP...") + upload_local_area_file( + hdfstore_path, "cities_hdfstore", skip_hf=True + ) + + hf_queue.append((h5_path, "cities")) + if has_hdfstore: + hf_queue.append((hdfstore_path, "cities_hdfstore")) record_completed_city("NYC") print("Completed NYC") diff --git a/policyengine_us_data/calibration/stacked_dataset_builder.py b/policyengine_us_data/calibration/stacked_dataset_builder.py index 0089f0d1..466d8e5b 100644 --- a/policyengine_us_data/calibration/stacked_dataset_builder.py +++ b/policyengine_us_data/calibration/stacked_dataset_builder.py @@ -1,7 +1,7 @@ """ CLI for creating CD-stacked datasets from calibration artifacts. 
-Thin wrapper around build_h5/build_states/build_districts/build_cities +Thin wrapper around build_output_dataset/build_states/build_districts/build_cities in publish_local_area.py. Loads a GeographyAssignment from geography.npz and delegates all H5 building logic. """ @@ -19,7 +19,7 @@ from policyengine_us import Microsimulation from policyengine_us_data.calibration.publish_local_area import ( - build_h5, + build_output_dataset, build_states, build_districts, build_cities, @@ -111,13 +111,13 @@ # === Dispatch === if mode == "national": - output_path = output_dir / "US.h5" - print(f"\nCreating national dataset: {output_path}") - build_h5( + output_base = output_dir / "US" + print(f"\nCreating national dataset: {output_base}") + build_output_dataset( weights=w, geography=geography, dataset_path=dataset_path, - output_path=output_path, + output_base=output_base, takeup_filter=takeup_filter, ) @@ -160,13 +160,13 @@ calibrated_cds = sorted(set(cd_geoid)) if args.cd not in calibrated_cds: raise ValueError(f"CD {args.cd} not in calibrated CDs") - output_path = output_dir / f"{args.cd}.h5" - print(f"\nCreating single CD dataset: {output_path}") - build_h5( + output_base = output_dir / args.cd + print(f"\nCreating single CD dataset: {output_base}") + build_output_dataset( weights=w, geography=geography, dataset_path=dataset_path, - output_path=output_path, + output_base=output_base, cd_subset=[args.cd], takeup_filter=takeup_filter, ) diff --git a/policyengine_us_data/db/etl_age.py b/policyengine_us_data/db/etl_age.py index db5e54da..658b76f0 100644 --- a/policyengine_us_data/db/etl_age.py +++ b/policyengine_us_data/db/etl_age.py @@ -92,7 +92,6 @@ def transform_age_data(age_data, docs): def load_age_data(df_long, geo, year): - # Quick data quality check before loading ---- if geo == "National": assert len(set(df_long.ucgid_str)) == 1 diff --git a/policyengine_us_data/db/etl_irs_soi.py b/policyengine_us_data/db/etl_irs_soi.py index f2b17795..b5999e48 100644 --- 
a/policyengine_us_data/db/etl_irs_soi.py +++ b/policyengine_us_data/db/etl_irs_soi.py @@ -238,7 +238,6 @@ def extract_soi_data() -> pd.DataFrame: def transform_soi_data(raw_df): - TARGETS = [ dict(code="59661", name="eitc", breakdown=("eitc_child_count", 0)), dict(code="59662", name="eitc", breakdown=("eitc_child_count", 1)), diff --git a/policyengine_us_data/db/etl_medicaid.py b/policyengine_us_data/db/etl_medicaid.py index 2c467799..30bae90a 100644 --- a/policyengine_us_data/db/etl_medicaid.py +++ b/policyengine_us_data/db/etl_medicaid.py @@ -150,7 +150,6 @@ def transform_survey_medicaid_data(cd_survey_df): def load_medicaid_data(long_state, long_cd, year): - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" engine = create_engine(DATABASE_URL) diff --git a/policyengine_us_data/db/etl_snap.py b/policyengine_us_data/db/etl_snap.py index dc5975a4..d21260d1 100644 --- a/policyengine_us_data/db/etl_snap.py +++ b/policyengine_us_data/db/etl_snap.py @@ -153,7 +153,6 @@ def transform_survey_snap_data(raw_df): def load_administrative_snap_data(df_states, year): - DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" engine = create_engine(DATABASE_URL) diff --git a/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py b/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py index 339dec4e..d3054ce8 100644 --- a/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py +++ b/policyengine_us_data/tests/test_calibration/test_stacked_dataset_builder.py @@ -1,4 +1,4 @@ -"""Tests for build_h5 using deterministic test fixture.""" +"""Tests for build_output_dataset using deterministic test fixture.""" import os import tempfile @@ -9,7 +9,7 @@ from pathlib import Path from policyengine_us import Microsimulation from policyengine_us_data.calibration.publish_local_area import ( - build_h5, + build_output_dataset, ) from 
policyengine_us_data.calibration.clone_and_assign import ( GeographyAssignment, @@ -83,17 +83,17 @@ def stacked_result(test_weights, n_households): """Run stacked dataset builder and return results.""" geography = _make_geography(n_households, TEST_CDS) with tempfile.TemporaryDirectory() as tmpdir: - output_path = os.path.join(tmpdir, "test_output.h5") + output_base = os.path.join(tmpdir, "test_output") - build_h5( + build_output_dataset( weights=np.array(test_weights), geography=geography, dataset_path=Path(FIXTURE_PATH), - output_path=Path(output_path), + output_base=Path(output_base), cd_subset=TEST_CDS, ) - sim_after = Microsimulation(dataset=output_path) + sim_after = Microsimulation(dataset=output_base + ".h5") hh_df = pd.DataFrame( sim_after.calculate_dataframe( [ @@ -172,17 +172,17 @@ def stacked_sim(test_weights, n_households): """Run stacked dataset builder and return the simulation.""" geography = _make_geography(n_households, TEST_CDS) with tempfile.TemporaryDirectory() as tmpdir: - output_path = os.path.join(tmpdir, "test_output.h5") + output_base = os.path.join(tmpdir, "test_output") - build_h5( + build_output_dataset( weights=np.array(test_weights), geography=geography, dataset_path=Path(FIXTURE_PATH), - output_path=Path(output_path), + output_base=Path(output_base), cd_subset=TEST_CDS, ) - sim = Microsimulation(dataset=output_path) + sim = Microsimulation(dataset=output_base + ".h5") yield sim @@ -197,15 +197,15 @@ def stacked_sim_with_overlap(n_households): geography = _make_geography(n_households, TEST_CDS) with tempfile.TemporaryDirectory() as tmpdir: - output_path = os.path.join(tmpdir, "test_overlap.h5") - build_h5( + output_base = os.path.join(tmpdir, "test_overlap") + build_output_dataset( weights=np.array(w), geography=geography, dataset_path=Path(FIXTURE_PATH), - output_path=Path(output_path), + output_base=Path(output_base), cd_subset=TEST_CDS, ) - sim = Microsimulation(dataset=output_path) + sim = Microsimulation(dataset=output_base + 
".h5") yield { "sim": sim, "n_overlap": len(overlap_households), diff --git a/policyengine_us_data/tests/test_calibration/test_xw_consistency.py b/policyengine_us_data/tests/test_calibration/test_xw_consistency.py index 403fe1af..fac6bd52 100644 --- a/policyengine_us_data/tests/test_calibration/test_xw_consistency.py +++ b/policyengine_us_data/tests/test_calibration/test_xw_consistency.py @@ -44,7 +44,7 @@ def test_xw_matches_stacked_sim(): UnifiedMatrixBuilder, ) from policyengine_us_data.calibration.publish_local_area import ( - build_h5, + build_output_dataset, ) from policyengine_us_data.utils.takeup import ( TAKEUP_AFFECTED_TARGETS, @@ -105,17 +105,17 @@ def test_xw_matches_stacked_sim(): tmpdir = tempfile.mkdtemp() for cd in top_cds: - h5_path = f"{tmpdir}/{cd}.h5" - build_h5( + output_base = f"{tmpdir}/{cd}" + build_output_dataset( weights=w, geography=geography, dataset_path=Path(DATASET_PATH), - output_path=Path(h5_path), + output_base=Path(output_base), cd_subset=[cd], takeup_filter=takeup_filter, ) - stacked_sim = Microsimulation(dataset=h5_path) + stacked_sim = Microsimulation(dataset=output_base + ".h5") hh_weight = stacked_sim.calculate( "household_weight", 2024, map_to="household" ).values diff --git a/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py b/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py index 298de5a4..5ebad600 100644 --- a/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py +++ b/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py @@ -145,7 +145,6 @@ def test_undocumented_matches_ssn_none(): def test_aca_calibration(): - import pandas as pd from pathlib import Path from policyengine_us import Microsimulation @@ -231,7 +230,6 @@ def test_immigration_status_diversity(): def test_medicaid_calibration(): - import pandas as pd from pathlib import Path from policyengine_us import Microsimulation diff --git a/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py 
b/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py index a7ee941b..0314d8e4 100644 --- a/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py +++ b/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py @@ -123,7 +123,6 @@ def test_sparse_ecps_replicates_jct_tax_expenditures(): def deprecated_test_sparse_ecps_replicates_jct_tax_expenditures_full(sim): - # JCT tax expenditure targets EXPENDITURE_TARGETS = { "salt_deduction": 21.247e9, @@ -158,7 +157,6 @@ def apply(self): def test_sparse_ssn_card_type_none_target(sim): - TARGET_COUNT = 13e6 TOLERANCE = 0.2 # Allow 20% error @@ -176,7 +174,6 @@ def test_sparse_ssn_card_type_none_target(sim): def test_sparse_aca_calibration(sim): - TARGETS_PATH = Path( "policyengine_us_data/storage/calibration_targets/aca_spending_and_enrollment_2024.csv" ) @@ -210,7 +207,6 @@ def test_sparse_aca_calibration(sim): def test_sparse_medicaid_calibration(sim): - TARGETS_PATH = Path( "policyengine_us_data/storage/calibration_targets/medicaid_enrollment_2024.csv" ) diff --git a/policyengine_us_data/tests/test_format_comparison.py b/policyengine_us_data/tests/test_format_comparison.py new file mode 100644 index 00000000..e5d653a1 --- /dev/null +++ b/policyengine_us_data/tests/test_format_comparison.py @@ -0,0 +1,401 @@ +""" +ONE-OFF VALIDATION SCRIPT + +This is a one-off script used to verify that the h5py-to-HDFStore +conversion logic is correct. It reads an existing h5py dataset file, +converts it to entity-level Pandas HDFStore using the production +splitting/dedup logic, then compares all variables to verify the +conversion is lossless. + +This script is NOT part of the regular test suite and is not intended to +be run in CI. It exists to validate the HDFStore serialization logic +during development. 
+ +Usage (run directly to avoid policyengine_us_data __init__ imports): + python policyengine_us_data/tests/test_format_comparison.py \ + --h5py-path path/to/STATE.h5 +""" + +import argparse +import sys + +import h5py +import numpy as np +import pandas as pd +import pytest + +from policyengine_us_data.utils.hdfstore import ( + ENTITIES, + DatasetResult, + save_hdfstore, +) + + +def _load_system(): + """Load the policyengine-us tax-benefit system.""" + from policyengine_us import system as us_system + + return us_system.system + + +# --------------------------------------------------------------------------- +# h5py reading helpers (test-specific; reads the flat h5py format +# and wraps it into the nested {var: {period: array}} structure +# expected by the production HDFStore utilities) +# --------------------------------------------------------------------------- + + +def _read_h5py_arrays(h5py_path: str): + """Read all arrays from an h5py variable-centric file. + + The h5py format stores ``variable / period -> array``. Periods can be + yearly (``"2024"``), monthly (``"2024-01"``), or ``"ETERNITY"``. + + Returns ``(data, time_period, h5_vars)`` where *data* is a nested dict + ``{variable_name: {period_key: numpy_array}}`` matching the format + used by the production HDFStore utilities. 
+ """ + with h5py.File(h5py_path, "r") as f: + h5_vars = sorted(f.keys()) + + # Determine the canonical year from the first variable that has one + year = None + for var in h5_vars: + subkeys = list(f[var].keys()) + for sk in subkeys: + if sk.isdigit() and len(sk) == 4: + year = sk + break + if year is not None: + break + if year is None: + raise ValueError("Could not determine year from h5py file") + + time_period = int(year) + data = {} + + for var in h5_vars: + subkeys = list(f[var].keys()) + if year in subkeys: + period_key = year + elif "ETERNITY" in subkeys: + period_key = "ETERNITY" + else: + period_key = subkeys[0] + + arr = f[var][period_key][:] + if arr.dtype.kind in ("S", "O"): + arr = np.array( + [x.decode() if isinstance(x, bytes) else str(x) for x in arr] + ) + # Wrap in nested dict keyed by the period string + data[var] = {period_key: arr} + + return data, time_period, h5_vars + + +# --------------------------------------------------------------------------- +# Main conversion + comparison logic +# --------------------------------------------------------------------------- + + +def h5py_to_hdfstore(h5py_path: str, hdfstore_path: str) -> dict: + """Convert an h5py variable-centric file to entity-level HDFStore. + + Uses the production HDFStore utilities so this test validates the + real code path rather than a local reimplementation. + + Returns a summary dict with entity row counts. 
+ """ + print("Loading policyengine-us system (this takes a minute)...") + system = _load_system() + + print("Reading h5py file...") + data, time_period, h5_vars = _read_h5py_arrays(h5py_path) + n_persons = len(next(iter(data.get("person_id", {}).values()), [])) + print(f" {len(h5_vars)} variables, {n_persons:,} persons, year={time_period}") + + result = DatasetResult(data=data, time_period=time_period, system=system) + output_base = hdfstore_path.replace(".hdfstore.h5", "") + + print(f"Saving HDFStore to {hdfstore_path}...") + save_hdfstore(result, output_base) + + summary = {} + with pd.HDFStore(hdfstore_path, "r") as store: + for k in store.keys(): + if not k.startswith("/_"): + name = k.lstrip("/") + df = store[k] + summary[name] = {"rows": len(df), "cols": len(df.columns)} + if "/_variable_metadata" in store.keys(): + summary["manifest_vars"] = len(store["/_variable_metadata"]) + return summary + + +def compare_formats(h5py_path: str, hdfstore_path: str) -> dict: + """Compare all variables between h5py and generated HDFStore. + + Returns a dict with keys: passed, failed, skipped. 
+ """ + passed = [] + failed = [] + skipped = [] + + with h5py.File(h5py_path, "r") as f: + h5_vars = sorted(f.keys()) + + # Determine the year + year = None + for var in h5_vars: + for sk in f[var].keys(): + if sk.isdigit() and len(sk) == 4: + year = sk + break + if year is not None: + break + + with pd.HDFStore(hdfstore_path, "r") as store: + store_keys = [k for k in store.keys() if not k.startswith("/_")] + entity_dfs = {k: store[k] for k in store_keys} + + for var in h5_vars: + subkeys = list(f[var].keys()) + if year in subkeys: + period_key = year + elif "ETERNITY" in subkeys: + period_key = "ETERNITY" + else: + period_key = subkeys[0] + + h5_values = f[var][period_key][:] + + found = False + for entity_key, df in entity_dfs.items(): + entity_name = entity_key.lstrip("/") + if var in df.columns: + hdf_values = df[var].values + + # For group entities, h5py is person-level while + # HDFStore is deduplicated by entity ID. + if entity_name != "person" and len(hdf_values) != len( + h5_values + ): + h5_unique = np.unique(h5_values) + hdf_unique = np.unique(hdf_values) + if h5_values.dtype.kind in ("U", "S", "O"): + match = set( + (x.decode() if isinstance(x, bytes) else str(x)) + for x in h5_unique + ) == set(str(x) for x in hdf_unique) + else: + match = np.allclose( + np.sort(h5_unique.astype(float)), + np.sort(hdf_unique.astype(float)), + rtol=1e-5, + equal_nan=True, + ) + if match: + passed.append(var) + else: + failed.append( + ( + var, + f"unique values differ " + f"(h5py: {len(h5_unique)}, " + f"hdfstore: {len(hdf_unique)})", + ) + ) + else: + # Same length — direct comparison + if h5_values.dtype.kind in ("U", "S", "O"): + h5_str = np.array( + [ + (x.decode() if isinstance(x, bytes) else str(x)) + for x in h5_values + ] + ) + hdf_str = np.array([str(x) for x in hdf_values]) + if np.array_equal(h5_str, hdf_str): + passed.append(var) + else: + mismatches = np.sum(h5_str != hdf_str) + failed.append( + ( + var, + f"{mismatches} string mismatches", + ) + ) + else: 
+ h5_float = h5_values.astype(float) + hdf_float = hdf_values.astype(float) + if np.allclose( + h5_float, + hdf_float, + rtol=1e-5, + equal_nan=True, + ): + passed.append(var) + else: + diff = np.abs(h5_float - hdf_float) + max_diff = np.max(diff) + n_diff = np.sum( + ~np.isclose( + h5_float, + hdf_float, + rtol=1e-5, + equal_nan=True, + ) + ) + failed.append( + ( + var, + f"{n_diff} values differ, " + f"max diff={max_diff:.6f}", + ) + ) + found = True + break + + if not found: + skipped.append(var) + + return { + "passed": passed, + "failed": failed, + "skipped": skipped, + "total_h5py_vars": len(h5_vars), + } + + +def print_results(result): + """Print comparison results to stdout.""" + print(f"\n{'=' * 60}") + print("Format Comparison Results") + print(f"{'=' * 60}") + print(f"Total h5py variables: {result['total_h5py_vars']}") + print(f"Passed: {len(result['passed'])}") + print(f"Failed: {len(result['failed'])}") + print(f"Skipped (not in HDFStore): {len(result['skipped'])}") + + if result["failed"]: + print("\nFailed variables:") + for var, reason in result["failed"]: + print(f" {var}: {reason}") + + if result["skipped"]: + print("\nSkipped variables (not found in HDFStore):") + for var in result["skipped"]: + print(f" {var}") + + +# --- pytest interface --- + + +def pytest_addoption(parser): + parser.addoption("--h5py-path", action="store", default=None) + + +@pytest.fixture +def h5py_path(request): + path = request.config.getoption("--h5py-path") + if path is None: + pytest.skip("--h5py-path not provided") + return path + + +def test_roundtrip(h5py_path, tmp_path): + """Convert h5py -> HDFStore -> compare all variables.""" + hdfstore_path = str(tmp_path / "test_output.hdfstore.h5") + + summary = h5py_to_hdfstore(h5py_path, hdfstore_path) + for entity, info in summary.items(): + if isinstance(info, dict): + print(f" {entity}: {info['rows']:,} rows, {info['cols']} cols") + + result = compare_formats(h5py_path, hdfstore_path) + print_results(result) + + assert 
len(result["failed"]) == 0, ( + f"{len(result['failed'])} variables have mismatched values" + ) + assert len(result["skipped"]) == 0, ( + f"{len(result['skipped'])} variables missing from HDFStore" + ) + + +def test_manifest(h5py_path, tmp_path): + """Verify the generated HDFStore contains a valid manifest.""" + hdfstore_path = str(tmp_path / "test_output.hdfstore.h5") + h5py_to_hdfstore(h5py_path, hdfstore_path) + + with pd.HDFStore(hdfstore_path, "r") as store: + assert "/_variable_metadata" in store.keys(), "Missing _variable_metadata table" + manifest = store["/_variable_metadata"] + assert "variable" in manifest.columns + assert "entity" in manifest.columns + assert "uprating" in manifest.columns + assert len(manifest) > 0, "Manifest is empty" + print(f"\nManifest has {len(manifest)} variables") + print(f"Entities: {manifest['entity'].unique().tolist()}") + n_uprated = (manifest["uprating"] != "").sum() + print(f"Variables with uprating: {n_uprated}") + + +def test_all_entities(h5py_path, tmp_path): + """Verify the generated HDFStore contains all expected entity tables.""" + hdfstore_path = str(tmp_path / "test_output.hdfstore.h5") + h5py_to_hdfstore(h5py_path, hdfstore_path) + + expected = set(ENTITIES) + with pd.HDFStore(hdfstore_path, "r") as store: + actual = {k.lstrip("/") for k in store.keys() if not k.startswith("/_")} + missing = expected - actual + assert not missing, f"Missing entity tables: {missing}" + for entity in expected: + df = store[f"/{entity}"] + assert len(df) > 0, f"Entity {entity} has 0 rows" + assert f"{entity}_id" in df.columns, ( + f"Entity {entity} missing {entity}_id column" + ) + print(f" {entity}: {len(df):,} rows, {len(df.columns)} cols") + + +# --- CLI interface --- + + +if __name__ == "__main__": + from pathlib import Path as _Path + + parser = argparse.ArgumentParser( + description="Convert h5py dataset to HDFStore and verify roundtrip" + ) + parser.add_argument("--h5py-path", required=True, help="Path to h5py format file") + 
# -------------------------------------------------------------------
# _resolve_period_key
# -------------------------------------------------------------------


class TestResolvePeriodKey:
    """Key-selection rules of ``_resolve_period_key`` for year 2024."""

    def test_int_key(self):
        periods = {2024: "a"}
        assert _resolve_period_key(periods, 2024) == 2024

    def test_str_key(self):
        periods = {"2024": "a"}
        assert _resolve_period_key(periods, 2024) == "2024"

    def test_int_preferred_over_str(self):
        periods = {2024: "a", "2024": "b"}
        assert _resolve_period_key(periods, 2024) == 2024

    def test_eternity_fallback(self):
        periods = {"ETERNITY": "a"}
        assert _resolve_period_key(periods, 2024) == "ETERNITY"

    def test_arbitrary_fallback(self):
        periods = {"2024-01": "a"}
        assert _resolve_period_key(periods, 2024) == "2024-01"

    def test_empty_returns_none(self):
        assert _resolve_period_key({}, 2024) is None


# -------------------------------------------------------------------
# save_h5 / save_hdfstore round-trip
# -------------------------------------------------------------------


class _FakeVariable:
    """Test double exposing only ``entity.key`` and ``uprating``."""

    def __init__(self, entity_key, uprating=""):
        # Throwaway class whose ``key`` attribute mimics an entity object.
        entity_cls = type("E", (), {"key": entity_key})
        self.entity = entity_cls()
        self.uprating = uprating


class _FakeSystem:
    """Test double exposing only a ``variables`` mapping."""

    def __init__(self, var_map):
        self.variables = var_map


def _make_tiny_result(time_period=2024):
    """Return a five-row ``DatasetResult`` keyed by ``time_period``."""
    n = 5
    entity_names = (
        "person",
        "household",
        "tax_unit",
        "spm_unit",
        "family",
        "marital_unit",
    )

    def ids():
        # A fresh array per column so no two variables share a buffer.
        return np.arange(n, dtype=np.int64)

    data = {f"{name}_id": {time_period: ids()} for name in entity_names}
    for name in entity_names[1:]:
        data[f"person_{name}_id"] = {time_period: ids()}
    data["age"] = {time_period: np.array([25, 30, 5, 60, 45])}
    data["household_weight"] = {time_period: np.ones(n)}

    variables = {f"{name}_id": _FakeVariable(name) for name in entity_names}
    variables["age"] = _FakeVariable("person")
    variables["household_weight"] = _FakeVariable("household", uprating="cpi")

    return DatasetResult(
        data=data,
        time_period=time_period,
        system=_FakeSystem(variables),
    )


def test_save_h5_roundtrip(tmp_path):
    """The file written by save_h5 holds the same ``age`` values."""
    result = _make_tiny_result()
    stem = str(tmp_path / "test")

    written = save_h5(result, stem)
    assert written == stem + ".h5"

    with h5py.File(written, "r") as f:
        assert "age" in f
        age_group = f["age"]
        assert "2024" in age_group
        np.testing.assert_array_equal(
            age_group["2024"][:], result.data["age"][2024]
        )


def test_save_hdfstore_roundtrip(tmp_path):
    """save_hdfstore writes entity tables plus metadata keys."""
    result = _make_tiny_result()
    stem = str(tmp_path / "test")

    written = save_hdfstore(result, stem)
    assert written == stem + ".hdfstore.h5"

    with pd.HDFStore(written, "r") as store:
        keys = {k.lstrip("/") for k in store.keys()}

        for entity in ENTITIES:
            assert entity in keys, f"Missing entity: {entity}"
            df = store[f"/{entity}"]
            assert f"{entity}_id" in df.columns

        assert "_variable_metadata" in keys
        manifest = store["/_variable_metadata"]
        for column in ("variable", "entity", "uprating"):
            assert column in manifest.columns
        assert len(manifest) > 0

        assert "_time_period" in keys
        tp = store["/_time_period"]
        assert tp.iloc[0] == 2024


def test_save_hdfstore_does_not_mutate_input(tmp_path):
    """The ``age`` array in the input is unchanged after save_hdfstore."""
    result = _make_tiny_result()
    age_before = result.data["age"][2024].copy()

    save_hdfstore(result, str(tmp_path / "test"))

    np.testing.assert_array_equal(result.data["age"][2024], age_before)


def test_eternity_variable_included(tmp_path):
    """A variable keyed only by ``"ETERNITY"`` lands in the person table."""
    n = 3
    id_columns = (
        "person_id",
        "household_id",
        "tax_unit_id",
        "spm_unit_id",
        "family_id",
        "marital_unit_id",
        "person_household_id",
    )
    data = {col: {2024: np.arange(n, dtype=np.int64)} for col in id_columns}
    data["is_male"] = {"ETERNITY": np.array([True, False, True])}

    variables = {
        "person_id": _FakeVariable("person"),
        "household_id": _FakeVariable("household"),
        "tax_unit_id": _FakeVariable("tax_unit"),
        "spm_unit_id": _FakeVariable("spm_unit"),
        "family_id": _FakeVariable("family"),
        "marital_unit_id": _FakeVariable("marital_unit"),
        "is_male": _FakeVariable("person"),
    }
    result = DatasetResult(
        data=data, time_period=2024, system=_FakeSystem(variables)
    )

    hdfstore_path = save_hdfstore(result, str(tmp_path / "test"))

    with pd.HDFStore(hdfstore_path, "r") as store:
        person_df = store["/person"]
        assert "is_male" in person_df.columns, (
            "ETERNITY-keyed variable missing from person entity"
        )
ENTITIES = [
    "person",
    "household",
    "tax_unit",
    "spm_unit",
    "family",
    "marital_unit",
]


@dataclass
class DatasetResult:
    """Typed container returned by ``build_output_dataset``."""

    data: Dict[str, Dict]  # {var_name: {period: np.ndarray}}
    time_period: int
    system: Any  # TaxBenefitSystem


# -------------------------------------------------------------------
# Internal helpers
# -------------------------------------------------------------------


def _resolve_period_key(periods: dict, time_period: int):
    """Find the best matching key in a variable's period dict.

    Lookup order:

    1. ``time_period`` itself (int key).
    2. ``str(time_period)``.
    3. Any key whose ``str()`` equals ``str(time_period)`` (for example a
       period-like object that prints as the year).
    4. The first available key (e.g. ``"ETERNITY"``).

    Args:
        periods: Maps period keys to arrays for a single variable.
        time_period: Year being extracted.

    Returns:
        The matching key, or ``None`` when *periods* is empty.
    """
    if time_period in periods:
        return time_period
    wanted = str(time_period)
    if wanted in periods:
        return wanted
    # Prefer a key that *prints* as the requested year over an arbitrary
    # first key, so a multi-period dict never yields the wrong year.
    for key in periods:
        if str(key) == wanted:
            return key
    return next(iter(periods), None)


def _split_data_into_entity_dfs(
    data: Dict[str, dict],
    system,
    time_period: int,
) -> Dict[str, pd.DataFrame]:
    """Split the data dict into per-entity DataFrames.

    Variables known to ``system`` go to the frame of their entity.
    Variables unknown to ``system`` fall back to the household frame,
    except the ``person_<entity>_id`` reference columns, which are always
    attached to the person frame only.

    Args:
        data: Maps variable names to ``{period: array}`` dicts.
        system: A PolicyEngine tax-benefit system.
        time_period: Year to extract from each variable's period dict.

    Returns:
        One DataFrame per entity, keyed by entity name. Entities with no
        columns are omitted. Group entities are deduplicated by their ID
        column.
    """
    # Reference columns are added to the person frame further down; they
    # must not leak into the household fallback, where their length
    # generally differs from the household columns.
    person_ref_cols = {f"person_{e}_id" for e in ENTITIES[1:]}

    entity_vars: Dict[str, list] = {e: [] for e in ENTITIES}
    for var_name in sorted(data):
        if var_name in system.variables:
            ek = system.variables[var_name].entity.key
            if ek in entity_vars:
                entity_vars[ek].append(var_name)
        elif var_name not in person_ref_cols:
            entity_vars["household"].append(var_name)

    entity_dfs: Dict[str, pd.DataFrame] = {}
    for entity in ENTITIES:
        id_col = f"{entity}_id"
        cols = {}
        for var_name in entity_vars[entity]:
            periods = data[var_name]
            tp_key = _resolve_period_key(periods, time_period)
            if tp_key is None:
                continue
            arr = periods[tp_key]
            # Byte-string arrays are decoded so the frame holds text.
            if hasattr(arr, "dtype") and arr.dtype.kind == "S":
                arr = np.char.decode(arr, "utf-8")
            cols[var_name] = arr

        if entity == "person":
            for ref_entity in ENTITIES[1:]:
                ref_col = f"person_{ref_entity}_id"
                if ref_col in data:
                    periods = data[ref_col]
                    tp_key = _resolve_period_key(periods, time_period)
                    if tp_key is not None:
                        cols[ref_col] = periods[tp_key]

        if not cols:
            continue

        df = pd.DataFrame(cols)
        if entity != "person" and id_col in df.columns:
            df = df.drop_duplicates(subset=[id_col]).reset_index(drop=True)
        entity_dfs[entity] = df

    return entity_dfs


def _build_uprating_manifest(
    data: Dict[str, dict],
    system,
) -> pd.DataFrame:
    """Build manifest of variable metadata for embedding in HDFStore.

    Args:
        data: Maps variable names to ``{period: array}`` dicts.
        system: A PolicyEngine tax-benefit system.

    Returns:
        DataFrame with columns: variable, entity, uprating (one row per
        variable, sorted by name). The columns are present even when
        *data* is empty. Variables unknown to ``system`` get entity
        ``"unknown"`` and an empty uprating.
    """
    columns = ["variable", "entity", "uprating"]
    records = []
    for var_name in sorted(data):
        if var_name in system.variables:
            variable = system.variables[var_name]
            entity = variable.entity.key
            # A missing or None uprating is normalised to "".
            uprating = getattr(variable, "uprating", None) or ""
        else:
            entity = "unknown"
            uprating = ""
        records.append(
            {
                "variable": var_name,
                "entity": entity,
                "uprating": uprating,
            }
        )
    return pd.DataFrame(records, columns=columns)


# -------------------------------------------------------------------
# Serializers
# -------------------------------------------------------------------


def save_h5(result: DatasetResult, output_base: str) -> str:
    """Write variable-centric h5py file.

    Each variable becomes a group and each period a dataset named
    ``str(period)`` inside it. After writing, the file is reopened and
    summary counts are printed for whichever of ``household_id``,
    ``person_id``, ``household_weight`` and ``person_weight`` exist for
    ``result.time_period``.

    Args:
        result: The assembled dataset.
        output_base: Path stem **without** file extension.

    Returns:
        Path to the created ``.h5`` file.
    """
    h5_path = str(output_base) + ".h5"
    with h5py.File(h5_path, "w") as f:
        for variable, periods in result.data.items():
            grp = f.create_group(variable)
            for period, values in periods.items():
                grp.create_dataset(str(period), data=values)

    print(f"\nH5 saved to {h5_path}")

    # Read back what was just written as a sanity check.
    with h5py.File(h5_path, "r") as f:
        tp = str(result.time_period)
        if "household_id" in f and tp in f["household_id"]:
            n = len(f["household_id"][tp][:])
            print(f"Verified: {n:,} households in output")
        if "person_id" in f and tp in f["person_id"]:
            n = len(f["person_id"][tp][:])
            print(f"Verified: {n:,} persons in output")
        if "household_weight" in f and tp in f["household_weight"]:
            hw = f["household_weight"][tp][:]
            print(f"Total population (HH weights): {hw.sum():,.0f}")
        if "person_weight" in f and tp in f["person_weight"]:
            pw = f["person_weight"][tp][:]
            print(f"Total population (person weights): {pw.sum():,.0f}")

    return h5_path


def save_hdfstore(result: DatasetResult, output_base: str) -> str:
    """Write entity-level Pandas HDFStore file.

    The store holds one table per entity, a ``_variable_metadata`` table
    (variable, entity, uprating) and a ``_time_period`` series.

    Args:
        result: The assembled dataset.
        output_base: Path stem **without** file extension.

    Returns:
        Path to the created ``.hdfstore.h5`` file.
    """
    hdfstore_path = str(output_base) + ".hdfstore.h5"

    entity_dfs = _split_data_into_entity_dfs(
        result.data, result.system, result.time_period
    )
    manifest_df = _build_uprating_manifest(result.data, result.system)

    print(f"\nSaving HDFStore to {hdfstore_path}...")

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            category=pd.errors.PerformanceWarning,
            message=".*PyTables will pickle object types.*",
        )
        with pd.HDFStore(hdfstore_path, mode="w") as store:
            for entity_name, df in entity_dfs.items():
                # Work on a copy so the caller's frames are not modified.
                df = df.copy()
                for col in df.columns:
                    if df[col].dtype == object:
                        df[col] = df[col].astype(str)
                store.put(entity_name, df, format="table")

            store.put("_variable_metadata", manifest_df, format="table")
            store.put(
                "_time_period",
                pd.Series([result.time_period]),
                format="table",
            )

    for entity_name, df in entity_dfs.items():
        print(f"  {entity_name}: {len(df):,} rows, {len(df.columns)} cols")
    print(f"  manifest: {len(manifest_df)} variables")
    print("HDFStore saved successfully!")

    return hdfstore_path