diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d044b3b..0543641 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -8,7 +8,7 @@ repos: - id: check-added-large-files - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.12.7" + rev: "v0.15.11" hooks: - id: ruff args: ["--fix", "--show-fixes"] diff --git a/racs_tools/beamcon_3D.py b/racs_tools/beamcon_3D.py index 58fcc6f..e5cb367 100644 --- a/racs_tools/beamcon_3D.py +++ b/racs_tools/beamcon_3D.py @@ -331,6 +331,9 @@ def _get_commonbeams( f"Expected target_beam to be type Beam or None, got {type(target_beam)}" ) + logger.info("Finding common beam for all channels and cubes") + logger.info(f"Number of channels {nchans=}") + if mode == "natural": big_beams = [] for n in trange( @@ -546,6 +549,7 @@ def _get_commonbeams( pa=commonbeams.pa * 0, ) + logger.info(f"Number of common beams: {len(commonbeams)}") return commonbeams @@ -765,10 +769,15 @@ def initfiles( ## Header spec_axis = wcs.spectral - crpix = int(spec_axis.wcs.crpix) + # account for either an float or array of single float. Anything else should fail! + crpix = ( + int(np.squeeze(spec_axis.wcs.crpix)) + if not np.isscalar(spec_axis.wcs.crpix) + else int(spec_axis.wcs.crpix) + ) nchans = spec_axis.array_shape[0] assert nchans == len(commonbeams), ( - "Number of channels in header and commonbeams do not match" + f"Number of channels {nchans=} in header and commonbeams {len(commonbeams)=} do not match" ) chans = np.arange(nchans) if ref_chan is None: diff --git a/racs_tools/convolve_uv.py b/racs_tools/convolve_uv.py index 0c90001..a77ee35 100644 --- a/racs_tools/convolve_uv.py +++ b/racs_tools/convolve_uv.py @@ -4,7 +4,7 @@ __author__ = "Wasim Raja" import gc -from typing import Literal, NamedTuple, Optional +from typing import Literal, NamedTuple import numpy as np import scipy.signal @@ -134,7 +134,7 @@ def convolve( new_beam: Beam, dx: u.Quantity, dy: u.Quantity, - cutoff: Optional[float] = None, + cutoff: float | None = None, ) -> ConvolutionResult: """Convolve by X-ing in the Fourier domain. - convolution with Gaussian kernels only @@ -248,7 +248,7 @@ def convolve_scipy( new_beam: Beam, dx: u.Quantity, dy: u.Quantity, - cutoff: Optional[float] = None, + cutoff: float | None = None, ) -> ConvolutionResult: """Convolve using scipy's convolution @@ -296,7 +296,7 @@ def convolve_astropy( new_beam: Beam, dx: u.Quantity, dy: u.Quantity, - cutoff: Optional[float] = None, + cutoff: float | None = None, ) -> ConvolutionResult: """Convolve using astropy's convolution @@ -347,7 +347,7 @@ def convolve_astropy_fft( new_beam: Beam, dx: u.Quantity, dy: u.Quantity, - cutoff: Optional[float] = None, + cutoff: float | None = None, ) -> ConvolutionResult: """Convolve using astropy's FFT convolution @@ -434,7 +434,7 @@ def get_convolving_beam( new_beam: Beam, dx: u.Quantity, dy: u.Quantity, - cutoff: Optional[float] = None, + cutoff: float | None = None, ) -> tuple[Beam, float]: """Get the beam to use for smoothing @@ -502,7 +502,7 @@ def smooth( dx: u.Quantity, dy: u.Quantity, conv_mode: Literal["robust", "scipy", "astropy", "astropy_fft"] = "robust", - cutoff: Optional[float] = None, + cutoff: float | None = None, ) -> np.ndarray: """Apply smoothing to image in Jy/beam diff --git a/racs_tools/getnoise_list.py b/racs_tools/getnoise_list.py index 3548b4d..b244955 100644 --- a/racs_tools/getnoise_list.py +++ b/racs_tools/getnoise_list.py @@ -3,7 +3,6 @@ import argparse import warnings -from typing import Union import astropy.units as u import numpy as np @@ -146,7 +145,7 @@ def main( blank: bool = False, cliplev: float = 5, iterate: int = 1, - outfile: Union[str, None] = None, + outfile: str | None = None, save_noise: bool = False, ) -> None: """Flag bad channels in Stokes Q and U cubes diff --git a/racs_tools/logging.py b/racs_tools/logging.py index f5e33ce..3179e3b 100644 --- a/racs_tools/logging.py +++ b/racs_tools/logging.py @@ -3,7 +3,6 @@ import logging import multiprocessing as mp from logging.handlers import QueueHandler, QueueListener -from typing import Optional logging.captureWarnings(True) @@ -12,7 +11,7 @@ def setup_logger( - filename: Optional[str] = None, + filename: str | None = None, ) -> tuple[logging.Logger, QueueListener, mp.Queue]: """Setup a logger diff --git a/tests/test_spectral.py b/tests/test_spectral.py new file mode 100644 index 0000000..705d7e8 --- /dev/null +++ b/tests/test_spectral.py @@ -0,0 +1,30 @@ +import numpy as np +import pytest +from astropy.io import fits +from astropy.wcs import WCS + + +@pytest.fixture +def spec_axis() -> WCS: + header = fits.Header() + header["NAXIS"] = 1 + header["NAXIS1"] = 288 + header["CTYPE1"] = "FREQ" + header["CRVAL1"] = 800e6 + header["CRPIX1"] = 1 + header["CDELT1"] = 1e6 + return WCS(header).spectral + + +def test_crpix(spec_axis: WCS) -> None: + crpix = ( + int(np.squeeze(spec_axis.wcs.crpix)) + if not np.isscalar(spec_axis.wcs.crpix) + else int(spec_axis.wcs.crpix) + ) + assert crpix == 1 + + +def test_nchan(spec_axis: WCS) -> None: + nchans = spec_axis.array_shape[0] + assert nchans == 288