Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.2.2] - Unreleased

- [#e42a4a73] QuasarNP internally now uses a WaveGrid object rather than numpy arrays
for wavelength grids. All code is backwards compatible.

## [0.2.1] - 2025-12-18
### Changed
Expand Down
36 changes: 26 additions & 10 deletions quasarnp/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np

from .model import QuasarNP
from .utils import rebin, renormalize, nbins, nbins_linear, wave, linear_wave
from .utils import rebin, renormalize, WaveGrid


def load_file(filename):
Expand All @@ -41,9 +41,22 @@ def load_file(filename):

try:
w_grid = f["model_grid"][:]

print(f"{w_grid=}")

log_grid = WaveGrid(linear=False)
linear_grid = WaveGrid(linear=True)
# Checking some defaults to correctly initialize the default grids.
if (len(w_grid) == len(log_grid)) and np.allclose(w_grid, log_grid.wave):
w_grid = WaveGrid(linear=False)
elif (len(w_grid) == len(linear_grid)) and np.allclose(w_grid, linear_grid.wave):
w_grid = WaveGrid(linear=True)
else:
w_grid = WaveGrid(grid=w_grid)

except KeyError:
print("Model grid not found in file, defaulting to logarithmic")
w_grid = wave
w_grid = WaveGrid(linear=False)

# Some versions of TF/Keras are 1 indexed and so bn layers start
# at batch_normalization_1. Some versions are 0 indexed and start at
Expand Down Expand Up @@ -347,7 +360,7 @@ def read_data(fi, truth=None, z_lim=2.1, return_pmf=False, nspec=None):

def load_desi_exposure(dir_name, spec_number,
fibers=np.ones(500, dtype="bool"),
out_grid=wave):
out_grid=WaveGrid(linear=False)):
"""Load and renormalize a raw DESI spectrographic exposure.

This method will load B, R and Z cframe files in sequence. First, spectra
Expand All @@ -368,7 +381,7 @@ def load_desi_exposure(dir_name, spec_number,
Array of length 500 indicating whether each fiber should be loaded.
True if the fiber should be loaded, False otherwise. Defaults to
True for all 500 fibers.
out_grid : numpy.ndarray, optional
out_grid : WaveGrid, optional
The wavelength grid to rebin the loaded exposure to. Defaults to the
logarithmic QuasarNET grid.

Expand Down Expand Up @@ -410,7 +423,7 @@ def load_desi_exposure(dir_name, spec_number,
# Load the flux and ivar
flux = h["FLUX"].read()[fibers, :]
ivar = h["IVAR"].read()[fibers, :]
w_grid = h["WAVELENGTH"].read()
w_grid = WaveGrid(grid=h["WAVELENGTH"].read())

# Rebin the flux and ivar
new_flux, new_ivar = rebin(flux, ivar, w_grid, out_grid=out_grid)
Expand All @@ -430,7 +443,7 @@ def load_desi_exposure(dir_name, spec_number,
return X_out, np.where(nonzero_weights)[0]


def load_desi_coadd(filename, rows=None, out_grid=wave):
def load_desi_coadd(filename, rows=None, out_grid=WaveGrid(linear=False)):
"""Load and renormalize a DESI coadded spectrographic exposure.

This method will load a coadd file and renormalize as follows. First,
Expand All @@ -450,7 +463,8 @@ def load_desi_coadd(filename, rows=None, out_grid=wave):
if the row should be loaded, False otherwise. Defaults to None, which
loads all rows.
out_grid : numpy.ndarray, optional
The wavelength grid to rebin the loaded exposure to.
The wavelength grid to rebin the loaded exposure to. Defaults to the
logarithmic QuasarNET grid.

Returns
-------
Expand Down Expand Up @@ -487,7 +501,7 @@ def load_desi_coadd(filename, rows=None, out_grid=wave):
# Load the flux and ivar
flux = h[fluxname].read()[rows, :]
ivar = h[ivarname].read()[rows, :]
w_grid = h[wname].read()
w_grid = WaveGrid(grid=h[wname].read())

# Rebin the flux and ivar
new_flux, new_ivar = rebin(flux, ivar, w_grid, out_grid=out_grid)
Expand All @@ -511,7 +525,7 @@ def load_desi_coadd(filename, rows=None, out_grid=wave):

def load_desi_daily(night, exp_id, spec_number,
fibers=np.ones(500, dtype="bool"),
w_grid=wave):
w_grid=WaveGrid(linear=False)):
"""Load and renormalize a daily DESI spectrographic exposure.

This method will load B, R and Z cframe files in sequence. First, spectra
Expand All @@ -535,7 +549,8 @@ def load_desi_daily(night, exp_id, spec_number,
True if the fiber should be loaded, False otherwise.
Defaults to True for all 500 fibers.
w_grid : numpy.ndarray, optional
The wavelength grid to rebin the loaded exposure to.
The wavelength grid to rebin the loaded exposure to. Defaults to the
logarithmic QuasarNET grid.

Returns
-------
Expand All @@ -560,6 +575,7 @@ def load_desi_daily(night, exp_id, spec_number,
# For now load daily cframes files
# TODO: add support for loading arbitrary cframes.
# TODO: Add support for loading by tile id + e rather than date + e
# TODO desispec.findfile instead.
root = "/global/cfs/cdirs/desi/spectro/redux/daily/exposures"
file_loc = Path(root, night, exp_id)

Expand Down
12 changes: 8 additions & 4 deletions quasarnp/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
import numpy as np

import quasarnp.io
from quasarnp.utils import wave, linear_wave
from quasarnp.utils import WaveGrid

file_loc = pathlib.Path(__file__).parent.resolve() / "test_files"


class TestLoadingModel(unittest.TestCase):
def setUp(self):
self.log_wave = WaveGrid(linear=False)
self.linear_wave = WaveGrid(linear=True)

def test_load_file(self):
# Get the location of this test script and load the test_weights file
# in this lower level directory.
Expand Down Expand Up @@ -64,7 +68,7 @@ def test_load_file(self):
observed = config_dict["conv_1"]["padding"]
self.assertEqual(observed, expected)

self.assertTrue(np.allclose(w_grid, wave))
self.assertTrue(np.allclose(w_grid.wave, self.log_wave.wave))

def test_load_linear_weights(self):
# Get the location of this test script and load the test_weights file
Expand All @@ -74,11 +78,11 @@ def test_load_linear_weights(self):
# This one should auto derive to log even though it's linear since
# we didn't post process to add the linear data.
*_, w_grid = quasarnp.io.load_file(loc)
self.assertTrue(np.allclose(w_grid, wave))
self.assertTrue(np.allclose(w_grid.wave, self.log_wave.wave))

loc = file_loc / "test_post_processed.h5"
*_, w_grid = quasarnp.io.load_file(loc)
self.assertTrue(np.allclose(w_grid, linear_wave))
self.assertTrue(np.allclose(w_grid.wave, self.linear_wave.wave))

class TestLoadingData(unittest.TestCase):
def test_load_desi_coadd(self):
Expand Down
30 changes: 17 additions & 13 deletions quasarnp/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@

import fitsio

from quasarnp.utils import regrid, process_preds, rebin, wave, linear_wave
from quasarnp.utils import regrid, process_preds, rebin, WaveGrid

file_loc = pathlib.Path(__file__).parent.resolve() / "test_files"


class TestUtilities(unittest.TestCase):
def setUp(self):
self.log_wave = WaveGrid(linear=False)
self.linear_wave = WaveGrid(linear=True)

# Test taking the old grid and generating which bins on the new grid
# the grid goes into.
def test_regrid_log(self):
# This is the regrids the grid to itself so shouldn't do anything
ob_bins, ob_keep = regrid(wave, wave)
ob_bins, ob_keep = regrid(self.log_wave, self.log_wave)
expected_bins = np.arange(443)

self.assertTrue(np.allclose(ob_bins, expected_bins))
Expand All @@ -25,7 +29,7 @@ def test_regrid_log(self):
# Testing regridding the DESI grid into the SDSS/QuasarNet grid.
wmin, wmax, wdelta = 3600, 9824, 0.8
old_grid = np.round(np.arange(wmin, wmax + wdelta, wdelta), 1)
ob_bins, ob_keep = regrid(old_grid, wave)
ob_bins, ob_keep = regrid(WaveGrid(grid=old_grid), self.log_wave)

# In order to not have to overload this file with nuisance, I have moved
# the actual answer here to regrid.txt. It's quite long, so only
Expand All @@ -40,17 +44,17 @@ def test_regrid_linear(self):
# This is the rebinned DESI grid, so regridding it shouldn't do anything.
# Linear DESI grid information
wmin, wmax, wdelta = 3600, 9824, 0.8
wdelta_qnet = wdelta * 17
wdelta_qnet = wdelta * 17 # For Linear Quasarnet grid
new_grid = np.round(np.arange(wmin, wmax + wdelta, wdelta_qnet), 1)

ob_bins, ob_keep = regrid(new_grid, linear_wave)
ob_bins, ob_keep = regrid(WaveGrid(grid=new_grid), self.linear_wave)
expected_bins = np.arange(458)
self.assertTrue(np.allclose(ob_bins, expected_bins))
self.assertTrue(np.allclose(ob_keep, np.ones_like(ob_keep, dtype=bool)))

# Testing regridding the DESI grid into the linear QuasarNet grid.
old_grid = np.round(np.arange(wmin, wmax + wdelta, wdelta), 1)
ob_bins, ob_keep = regrid(old_grid, linear_wave)
ob_bins, ob_keep = regrid(WaveGrid(grid=old_grid), self.linear_wave)

# 17 DESI bins per linear QuasarNET bin, but 17 * 458 is slightly
# longer than the true DESI grid, so the last bin only
Expand All @@ -62,7 +66,7 @@ def test_regrid_linear(self):
def test_regrid_arbitrary(self):
# Stephen Bailey's arbitrary grid
old_grid = np.arange(3600, 9800, 10)
ob_bins, ob_keep = regrid(old_grid, wave)
ob_bins, ob_keep = regrid(WaveGrid(grid=old_grid), self.log_wave)

# In order to not have to overload this file with nuisance, I have moved
# the actual answer here to regrid_arbitrary.txt. It's quite long, so only
Expand All @@ -74,11 +78,11 @@ def test_regrid_arbitrary(self):

def test_regrid_failure(self):
# Non constant binning should fail and raise a value error.
new_grid = np.concatenate([np.arange(3600, 4000, 10), np.arange(4000, 9800, 40)])
new_grid = WaveGrid(grid=np.concatenate([np.arange(3600, 4000, 10), np.arange(4000, 9800, 40)]))

# Testing regridding the DESI grid onto this broken grid
wmin, wmax, wdelta = 3600, 9824, 0.8
old_grid = np.round(np.arange(wmin, wmax + wdelta, wdelta), 1)
old_grid = WaveGrid(grid=np.round(np.arange(wmin, wmax + wdelta, wdelta), 1))

with self.assertRaises(ValueError):
_ = regrid(old_grid, new_grid)
Expand All @@ -103,10 +107,10 @@ def test_rebin(self):
# Load the flux and ivar
flux = h[fluxname].read()[:]
ivar = h[ivarname].read()[:]
w_grid = h[wname].read()
w_grid = WaveGrid(grid=h[wname].read())

# Rebin the flux and ivar
n_flux, n_ivar = rebin(flux, ivar, w_grid, out_grid=wave)
n_flux, n_ivar = rebin(flux, ivar, w_grid, out_grid=self.log_wave)

# Just checks that the rebinned is equal to the known
# "correct" rebinning
Expand All @@ -133,10 +137,10 @@ def test_rebin_linear(self):
# Load the flux and ivar
flux = h[fluxname].read()[:]
ivar = h[ivarname].read()[:]
w_grid = h[wname].read()
w_grid = WaveGrid(grid=h[wname].read())

# Rebin the flux and ivar
n_flux, n_ivar = rebin(flux, ivar, w_grid, out_grid=linear_wave)
n_flux, n_ivar = rebin(flux, ivar, w_grid, out_grid=self.linear_wave)

# Just checks that the rebinned is equal to the known
# "correct" rebinning
Expand Down
Loading