# Source code for pyxel.models.photon_collection.poppy

#  Copyright (c) European Space Agency, 2020.
#
#  This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
#  is part of this Pyxel package. No part of the package, including
#  this file, may be copied, modified, propagated, or distributed except according to
#  the terms contained in the file ‘LICENCE.txt’.

"""Poppy model."""

import logging
import textwrap
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Final, get_args

import numpy as np
import xarray as xr
from astropy.convolution import convolve_fft
from astropy.io import fits
from astropy.units import Quantity

from pyxel.detectors import Detector, WavelengthHandling
from pyxel.util import convert_unit

if TYPE_CHECKING:
    import poppy as op

# Hint appended to configuration errors: a minimal, valid YAML snippet that
# declares an optical system made of a single circular aperture.
GENERIC_ERROR_MESSAGE: Final[str] = "\n".join(
    [
        "To resolve this issue, you can use for example this input in the YAML configuration:",
        "arguments:",
        "  optical_system:",
        "    - item: CircularAperture",
        "      radius: 1.0     # radius in meter",
    ]
)


@dataclass
class CircularAperture:
    """Ideal circular pupil aperture.

    Parameters
    ----------
    radius : Quantity
        Pupil radius, expressed in meters.
    """

    radius: Quantity  # pupil radius [m]


@dataclass
class DeprecatedThinLens:
    """Idealized thin lens (legacy variant with plain-float fields).

    Unlike ``ThinLens``, the radius is a plain float and no reference
    wavelength can be given. This class is not a member of ``OpticalParameter``.

    Parameters
    ----------
    nwaves : float
        Peak-to-valley defocus, expressed as a number of waves.
    radius : float
        Pupil radius in meters. The Zernike defocus term is normalized so
        that rho = 1 at r = `radius`.
    """

    nwaves: float  # peak-to-valley defocus [waves]
    radius: float  # pupil radius [m]


@dataclass
class ThinLens:
    """Idealized thin lens.

    Parameters
    ----------
    nwaves : float
        Peak-to-valley defocus, expressed as a number of waves.
    radius : Quantity
        Pupil radius in meters. The Zernike defocus term is normalized so
        that rho = 1 at r = `radius`.
    reference_wavelength : Quantity or None, default: None
        Wavelength (in nm) at which ``nwaves`` is specified.
        ``_create_optical_parameter`` always fills it. It uses the value from
        the input dictionary if present, otherwise the default wavelength
        (the midpoint when a ``(cut_on, cut_off)`` range is given).
    """

    nwaves: float  # peak-to-valley defocus [waves]
    radius: Quantity  # pupil radius [m]
    reference_wavelength: Quantity | None = None  # [nm]


@dataclass
class SquareAperture:
    """Ideal square pupil aperture.

    Parameters
    ----------
    size : Quantity
        Length of one side of the square, expressed in meters.
    """

    size: Quantity  # side length [m]


@dataclass
class RectangleAperture:
    """Ideal rectangular pupil aperture.

    Parameters
    ----------
    width : Quantity
        Width of the rectangle, expressed in meters.
    height : Quantity
        Height of the rectangle, expressed in meters.
    """

    width: Quantity  # [m]
    height: Quantity  # [m]


@dataclass
class HexagonAperture:
    """Ideal hexagonal pupil aperture.

    Parameters
    ----------
    side : Quantity
        Side length of the hexagon (equal to its circumradius), in meters.
    """

    side: Quantity  # side length / radius [m]


@dataclass
class MultiHexagonalAperture:
    """Hexagonally segmented pupil aperture.

    Parameters
    ----------
    side : Quantity
        Side length (and/or radius) of each hexagonal segment, in meters.
    rings : int
        Number of rings of hexagons around the central segment, the central
        segment itself not being counted (e.g. 2 for a JWST-like aperture,
        3 for a Keck-like aperture).
    gap : Quantity
        Gap between adjacent segments, in meters.
    """

    side: Quantity  # segment side length [m]
    rings: int  # rings around the central segment
    gap: Quantity  # inter-segment gap [m]


@dataclass
class SecondaryObscuration:
    """Central obscuration of an on-axis telescope.

    Describes the secondary mirror together with its supports.

    Parameters
    ----------
    secondary_radius : Quantity
        Radius of the circular secondary obscuration, in meters.
    n_supports : int
        Number of secondary mirror supports ("spiders"), equally spaced
        around a circle.
    support_width : Quantity
        Width of each support, in meters.
    """

    secondary_radius: Quantity  # [m]
    n_supports: int  # number of spiders
    support_width: Quantity  # [m]


@dataclass
class ZernikeWFE:
    """Optical element described by its Zernike components.

    Parameters
    ----------
    radius : Quantity
        Pupil radius in meters. The Zernike terms are normalized so that
        rho = 1 at r = `radius`.
    coefficients : sequence of float
        Zernike coefficients, ordered following the Noll et al. (JOSA 1976)
        convention. Each coefficient is in meters of optical path difference
        (not in waves).
    aperture_stop : float
        Forwarded to poppy's ``ZernikeWFE(aperture_stop=...)``. The parser
        converts it with ``float()``, although it is presumably used as a
        boolean flag. TODO confirm.
    """

    radius: Quantity  # [m]
    coefficients: Sequence[float]  # Noll-ordered, meters of OPD
    aperture_stop: float


@dataclass
class SineWaveWFE:
    """Single sine-wave ripple applied across the optic.

    Parameters
    ----------
    spatialfreq : Quantity
        Spatial frequency of the ripple (the parser builds it in 1/m).
    amplitude : Quantity
        Amplitude of the ripple (the parser builds it in micrometers).
    rotation : float
        Rotation of the ripple, forwarded as-is to poppy's ``SineWaveWFE``.
        Presumably in degrees, as interpreted by poppy. Confirm.
    """

    spatialfreq: Quantity  # [1/m]
    amplitude: Quantity  # [um]
    rotation: float


# Union of every optical-element parameter class produced by
# '_create_optical_parameter'. Member order matters: 'get_args(OpticalParameter)'
# is used to list the valid element names in the "unknown element" error.
# 'DeprecatedThinLens' is not a member.
OpticalParameter = (
    CircularAperture | ThinLens | SquareAperture | RectangleAperture
    | HexagonAperture | MultiHexagonalAperture | SecondaryObscuration
    | ZernikeWFE | SineWaveWFE
)


def _create_optical_parameter(
    dct: Mapping, default_wavelength: Quantity | tuple[Quantity, Quantity]
) -> OpticalParameter:
    """Create an ``OpticalParameter`` based on a dictionary.

    Parameters
    ----------
    dct : dict
        Dictionary to convert. It must contain the key ``'item'`` (name of the
        optical element) and the parameters required by that element.
    default_wavelength : Union[Quantity, tuple[Quantity,Quantity]]
        Wavelength in nanometer. Only used for ``'ThinLens'`` when no
        ``'reference_wavelength'`` is provided. A single wavelength is used
        as-is; a ``(cut_on, cut_off)`` range is replaced by its midpoint.

    Returns
    -------
    OpticalParameter
        New parameters.

    Raises
    ------
    KeyError
        If ``'item'`` is missing or unknown, or if a parameter required by the
        selected element is missing.
    ValueError
        If ``'coefficients'`` of a ``'ZernikeWFE'`` is not a non-empty list
        of numbers.
    """
    if "item" not in dct:
        raise KeyError(
            f"Missing keyword 'item'. Got: {dct!r}.\n{GENERIC_ERROR_MESSAGE}"
        )

    item = dct["item"]

    if item == "CircularAperture":
        if "radius" not in dct:
            raise KeyError(
                "Missing parameter 'radius' for the optical element 'CircularAperture'."
            )

        return CircularAperture(radius=Quantity(dct["radius"], unit="m"))

    elif item == "ThinLens":
        if "reference_wavelength" in dct:
            reference_wavelength = Quantity(dct["reference_wavelength"], unit="nm")
        elif isinstance(default_wavelength, Quantity):
            reference_wavelength = default_wavelength
        else:
            # A (cut_on, cut_off) range: use its midpoint.
            cut_on, cut_off = default_wavelength
            reference_wavelength = (cut_on + cut_off) / 2

        if "nwaves" not in dct or "radius" not in dct:
            raise KeyError(
                "Missing one of these parameters: 'nwaves', 'radius' "
                "for the optical element 'ThinLens'."
            )

        return ThinLens(
            nwaves=float(dct["nwaves"]),
            radius=Quantity(dct["radius"], unit="m"),
            reference_wavelength=reference_wavelength,
        )

    elif item == "SquareAperture":
        if "size" not in dct:
            raise KeyError(
                "Missing parameter 'size' for the optical element 'SquareAperture'."
            )

        return SquareAperture(size=Quantity(dct["size"], unit="m"))

    elif item in ("RectangleAperture", "RectangularAperture"):
        # 'RectangleAperture' is the class name listed in the "unknown element"
        # error below (and the poppy class name), so it is accepted in addition
        # to 'RectangularAperture'.
        if "width" not in dct or "height" not in dct:
            raise KeyError(
                "Missing one of these parameters: 'width', 'height' "
                f"for the optical element {item!r}."
            )

        return RectangleAperture(
            width=Quantity(dct["width"], unit="m"),
            height=Quantity(dct["height"], unit="m"),
        )

    elif item == "HexagonAperture":
        if "side" not in dct:
            raise KeyError(
                "Missing parameter 'side' for the optical element 'HexagonAperture'."
            )

        return HexagonAperture(side=Quantity(dct["side"], unit="m"))

    elif item == "MultiHexagonalAperture":
        if "side" not in dct or "rings" not in dct or "gap" not in dct:
            raise KeyError(
                "Missing one of these parameters: 'side', 'rings', 'gap' "
                "for the optical element 'MultiHexagonalAperture'."
            )

        return MultiHexagonalAperture(
            side=Quantity(dct["side"], unit="m"),
            rings=int(dct["rings"]),
            gap=Quantity(dct["gap"], unit="m"),
        )

    elif item == "SecondaryObscuration":
        if (
            "secondary_radius" not in dct
            or "n_supports" not in dct
            or "support_width" not in dct
        ):
            raise KeyError(
                "Missing one of these parameters: 'secondary_radius', 'n_supports', 'support_width' "
                "for the optical element 'SecondaryObscuration'."
            )

        return SecondaryObscuration(
            secondary_radius=Quantity(dct["secondary_radius"], unit="m"),
            n_supports=int(dct["n_supports"]),
            support_width=Quantity(dct["support_width"], unit="m"),
        )

    elif item == "ZernikeWFE":
        if (
            "radius" not in dct
            or "coefficients" not in dct
            or "aperture_stop" not in dct
        ):
            raise KeyError(
                "Missing one of these parameters: 'radius', 'coefficients', 'aperture_stop' "
                "for the optical element 'ZernikeWFE'."
            )

        coefficients = dct["coefficients"]
        # A 'str' is also a 'Sequence', but it is never a valid list of numbers.
        if (
            isinstance(coefficients, str)
            or not isinstance(coefficients, Sequence)
            or len(coefficients) == 0
        ):
            raise ValueError(
                "Expecting a list of numbers for parameter 'coefficients' "
                "for the optical element 'ZernikeWFE'."
            )

        return ZernikeWFE(
            radius=Quantity(dct["radius"], unit="m"),
            coefficients=coefficients,  # list of floats
            # NOTE(review): presumably a boolean flag, but parsed as float.
            aperture_stop=float(dct["aperture_stop"]),
        )

    elif item == "SineWaveWFE":
        if "spatialfreq" not in dct or "amplitude" not in dct or "rotation" not in dct:
            raise KeyError(
                "Missing one of these parameters: 'spatialfreq', 'amplitude', 'rotation' "
                "for the optical element 'SineWaveWFE'."
            )

        return SineWaveWFE(
            spatialfreq=Quantity(dct["spatialfreq"], unit="1/m"),
            amplitude=Quantity(dct["amplitude"], unit="um"),
            rotation=float(dct["rotation"]),
        )

    else:
        valid_optical_elements: Sequence[str] = [
            repr(cls.__name__) for cls in get_args(OpticalParameter)
        ]
        msg = f"Unknown 'optical_element', expected values: {', '.join(valid_optical_elements)}. Got: {dct!r}"
        msg_lst: list[str] = textwrap.wrap(msg, drop_whitespace=False)
        raise KeyError("\n".join(msg_lst))


def create_optical_item(
    dct: Mapping,
    default_wavelength: Quantity | tuple[Quantity, Quantity],
) -> "op.OpticalElement":
    """Create a poppy ``OpticalElement``.

    The dictionary is first validated and converted into an
    ``OpticalParameter`` by ``_create_optical_parameter``. The result is then
    mapped onto the matching poppy class.

    Parameters
    ----------
    dct : dict
        Dictionary to convert.
    default_wavelength : Union[Quantity, tuple[Quantity,Quantity]]
        Wavelength in nanometer, used by ``ThinLens`` when no
        ``reference_wavelength`` is provided.

    Returns
    -------
    ``OpticalElement``
        A poppy ``OpticalElement``.

    Raises
    ------
    ModuleNotFoundError
        If the optional package ``poppy`` is not installed.
    KeyError, ValueError
        Propagated from ``_create_optical_parameter`` for invalid dictionaries.
    NotImplementedError
        If the parameter type has no poppy counterpart here.
    """
    try:
        import poppy as op
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError(
            "Missing optional package 'poppy'.\n"
            "Please install it with 'pip install pyxel-sim[model]' "
            "or 'pip install pyxel-sim[all]' or 'pip install poppy'"
        ) from exc

    param: OpticalParameter = _create_optical_parameter(
        dct=dct,
        default_wavelength=default_wavelength,
    )

    if isinstance(param, CircularAperture):
        return op.CircularAperture(radius=param.radius)

    elif isinstance(param, ThinLens):
        return op.ThinLens(
            nwaves=param.nwaves,
            reference_wavelength=param.reference_wavelength,
            radius=param.radius,
        )

    elif isinstance(param, SquareAperture):
        return op.SquareAperture(size=param.size)

    elif isinstance(param, RectangleAperture):
        return op.RectangleAperture(width=param.width, height=param.height)

    elif isinstance(param, HexagonAperture):
        return op.HexagonAperture(side=param.side)

    elif isinstance(param, MultiHexagonalAperture):
        # Note: the poppy class is named 'MultiHexagonAperture'.
        return op.MultiHexagonAperture(
            side=param.side,
            rings=param.rings,
            gap=param.gap,
        )

    elif isinstance(param, SecondaryObscuration):
        return op.SecondaryObscuration(
            secondary_radius=param.secondary_radius,
            n_supports=param.n_supports,
            support_width=param.support_width,
        )

    elif isinstance(param, ZernikeWFE):
        return op.ZernikeWFE(
            radius=param.radius,
            coefficients=param.coefficients,
            aperture_stop=param.aperture_stop,
        )

    elif isinstance(param, SineWaveWFE):
        return op.SineWaveWFE(
            spatialfreq=param.spatialfreq,
            amplitude=param.amplitude,
            rotation=param.rotation,
        )

    else:
        raise NotImplementedError(f"{param=}")


def calc_psf(
    wavelengths: Sequence[float],
    fov_arcsec: float,
    pixel_scale: Quantity,
    optical_elements: Sequence["op.OpticalElement"],
    apply_jitter: bool = False,
    jitter_sigma: float = 0.007,
) -> tuple[fits.PrimaryHDU, fits.PrimaryHDU]:
    """Calculate the point spread function for the given optical system.

    Parameters
    ----------
    wavelengths : sequence of float
        Wavelengths of incoming light in meters.
    fov_arcsec : float
        Field Of View on detector plane in arcsec. It is forwarded to poppy's
        ``calc_datacube``, but the optical system built by the internal
        instrument uses a fixed 10-pixel detector field of view (see
        ``get_optical_system``).
    pixel_scale : Quantity
        Pixel scale on detector plane, convertible to arcsec/pixel.
        Defines sampling resolution of :term:`PSF`.
    optical_elements : list of OpticalElement
        Optical elements to apply, added as pupil planes in this order.
    apply_jitter : bool
        Defines whether jitter should be applied. Default = False.
    jitter_sigma : float
        Jitter sigma value in arcsec per axis, default is 0.007.

    Returns
    -------
    Tuple of two :term:`FITS`
        The two HDUs unpacked from the result of poppy's
        ``Instrument.calc_datacube``. The first one holds the PSF data cube,
        indexed by wavelength first.

    Raises
    ------
    ModuleNotFoundError
        If the optional package ``poppy`` is not installed.
    """
    try:
        import poppy as op
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError(
            "Missing optional package 'poppy'.\n"
            "Please install it with 'pip install pyxel-sim[model]' "
            "or 'pip install pyxel-sim[all]' or 'pip install poppy'"
        ) from exc

    class PyxelInstrument(op.instrument.Instrument):
        """Instrument class for Pyxel using poppy.instrument."""

        def __init__(
            self,
            pixel_scale: Quantity,
            optical_elements: Sequence["op.OpticalElement"],
            fov_arcsec: float = 2,
            name="PyxelInstrument",
        ):
            super().__init__(name=name)
            self._pixel_scale: Quantity = pixel_scale.to("arcsec/pix")
            self._optical_elements = optical_elements
            # NOTE(review): stored but not read by 'get_optical_system' below.
            self._fov_arcsec = fov_arcsec

        def get_optical_system(
            self,
            fft_oversample=2,
            detector_oversample=None,
            fov_arcsec=None,
            fov_pixels=None,
            options=None,
        ):
            """Return an OpticalSystem instance corresponding to the instrument as currently configured.

            The parameters follow the signature of
            ``poppy.instrument.Instrument.get_optical_system``. None of them is
            used by this implementation.

            Parameters
            ----------
            fft_oversample : int
                Oversampling factor for intermediate plane calculations (unused).
            detector_oversample : int, optional
                Detector oversampling factor (unused).
            fov_arcsec : float
                Field of view, in arcseconds (unused).
            fov_pixels : float
                Field of view in pixels (unused).
            options : dict
                Other arbitrary options for optical system creation (unused).

            Returns
            -------
            osys : poppy.OpticalSystem
                An optical system with every optical element added as a pupil
                plane, followed by a detector.
            """
            osys = op.OpticalSystem(npix=1000)  # default: 1024

            element: op.OpticalElement
            for element in self._optical_elements:
                osys.add_pupil(element)

            # NOTE(review): the detector field of view is hard-coded to 10
            # pixels; 'fov_arcsec'/'fov_pixels' and 'self._fov_arcsec' are
            # ignored. Confirm that this is intended.
            analysis_fov = self._pixel_scale * Quantity(10, unit="pix")
            osys.add_detector(
                pixelscale=self._pixel_scale,
                fov_arcsec=analysis_fov,
            )

            return osys

    output_fits: Sequence[fits.hdu.image.PrimaryHDU]
    wavefronts: Sequence[op.Wavefront]

    # Create Instrument
    instrument = PyxelInstrument(
        optical_elements=optical_elements,
        pixel_scale=pixel_scale,
        fov_arcsec=fov_arcsec,
    )

    instrument.pixelscale = pixel_scale

    if apply_jitter:
        instrument.options["jitter"] = "gaussian"
        instrument.options["jitter_sigma"] = (
            jitter_sigma  # in arcsec per axis, default 0.007
        )

    output_fits, wavefronts = instrument.calc_datacube(
        wavelengths=wavelengths,
        fov_arcsec=fov_arcsec,
        oversample=1,
    )

    return output_fits, wavefronts


def apply_convolution(data: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Convolve an array in 2D or 3D.

    Parameters
    ----------
    data : ndarray
        2D or 3D Array to be convolved with ``kernel``.
    kernel : ndarray
        The convolution kernel, 2D or 3D, whose last two dimensions must be
        equal (square). Kernels larger than 11x11 in those dimensions are
        resampled to 11x11, keeping any leading axis, and renormalized to a
        unit sum.

    Returns
    -------
    ndarray
        A convolved array.

    Raises
    ------
    NotImplementedError
        If ``kernel`` is neither 2D nor 3D.
    ValueError
        If the last two dimensions of ``kernel`` differ.
    """
    if kernel.ndim not in (2, 3):
        raise NotImplementedError

    # Values outside 'data' are padded with the mean of 'data'
    # (boundary='fill'). The padding must use the data statistics, not the
    # kernel's, in both the 2D and the 3D case.
    mean = np.mean(data)

    *_, num_rows, num_cols = kernel.shape
    if num_rows != num_cols:
        raise ValueError(
            f"Expecting a kernel with square spatial dimensions. Got: {kernel.shape}"
        )

    # Resize the kernel if it is too big.
    max_size = 11
    if num_rows > max_size:
        import skimage.transform as sk

        # Only the last two axes are resampled; a leading axis is kept as-is.
        new_shape: tuple[int, ...] = (*kernel.shape[:-2], max_size, max_size)
        resized_kernel = sk.resize(kernel, output_shape=new_shape, anti_aliasing=False)
        kernel = resized_kernel / resized_kernel.sum()

    # NOTE(review): with a 3D kernel this is a full 3D convolution, so the
    # first (wavelength) axis is convolved too and slices are mixed. Confirm
    # that this, rather than a per-slice 2D convolution, is intended.
    array = convolve_fft(
        data,
        kernel=kernel,
        boundary="fill",
        fill_value=mean,
    )

    return array


# ruff: noqa: C901
def optical_psf(
    detector: Detector,
    fov_arcsec: float,
    optical_system: Sequence[Mapping[str, Any]],
    wavelength: float | tuple[float, float] | None = None,
    pixel_scale: float | None = None,
    apply_jitter: bool = False,
    jitter_sigma: float = 0.007,
    extract_psf: bool = False,
) -> None:
    """Model function for poppy optics model: convolve photon array with psf.

    Parameters
    ----------
    detector : Detector
        Pyxel Detector object.
    fov_arcsec : float
        Field Of View on detector plane in arcsec.
    optical_system : list of dict
        List of optical elements before detector with their specific arguments.
    wavelength : Union[float, tuple[float, float], None]
        Wavelength of incoming light in nanometers. It is either a single
        value (2D photon array) or a ``(cut_on, cut_off)`` range (3D photon
        array). When None (default), it is taken from ``detector.environment``.
    pixel_scale : float, Optional, default: None
        Pixel scale of detector in arcsec/pix. When None, it is taken from
        ``detector.geometry``.
    apply_jitter : bool, default: False
        Defines whether jitter should be applied, default is False.
    jitter_sigma : float
        Jitter sigma value in arcsec per axis, default is 0.007.
    extract_psf : bool, default: False
        Copy the computed PSF into the data bucket
        ``detector.data['/photon_collection/optical_psf/[name of the model]/psf']``

    Raises
    ------
    ValueError
        If ``fov_arcsec``, ``pixel_scale`` or ``wavelength`` is invalid or
        undefined, if ``optical_system`` is empty, if the photon array
        dimensionality does not match the kind of wavelength, or if the
        wavelength range has no overlap with the photon wavelengths.

    Notes
    -----
    For more information, you can find examples here:

    * :external+pyxel_data:doc:`examples/models/scene_generation/tutorial_example_scene_generation`
    * :external+pyxel_data:doc:`tutorial/01_first_simulation`
    """
    logging.getLogger("poppy").setLevel(
        logging.WARNING
    )  # TODO: Fix this. See issue #81

    if fov_arcsec <= 0.0:
        raise ValueError(
            f"Expecting strictly positive value for 'fov_arcsec'. Got {fov_arcsec!r} "
        )

    if not optical_system:
        raise ValueError(
            "Parameter 'optical_system' does not contain any optical element(s)."
            f"\n{GENERIC_ERROR_MESSAGE}"
        )

    # get pixel scale either from detector geometry or from model input
    if pixel_scale is None:
        if detector.geometry._pixel_scale is None:
            raise ValueError(
                "Pixel scale is not defined. It must be either provided in the detector geometry "
                "or as model argument."
            )
        pixel_scale_with_unit: Quantity = Quantity(
            detector.geometry.pixel_scale,
            unit="arcsec/pix",
        )
    else:
        if pixel_scale <= 0:
            raise ValueError(
                f"Parameter 'pixel_scale' must be strictly positive. Got: {pixel_scale}"
            )
        pixel_scale_with_unit = Quantity(pixel_scale, unit="arcsec/pix")

    # get wavelength information either from detector environment or from model input
    selected_wavelength: Quantity | tuple[Quantity, Quantity]
    if wavelength is None:
        # take wavelength input from detector.environment
        if isinstance(detector.environment._wavelength, float):
            selected_wavelength = Quantity(detector.environment.wavelength, unit="nm")
        elif isinstance(detector.environment._wavelength, WavelengthHandling):
            selected_wavelength = (
                Quantity(detector.environment._wavelength.cut_on, unit="nm"),
                Quantity(detector.environment._wavelength.cut_off, unit="nm"),
            )
        else:
            raise ValueError(
                "Wavelength is not defined. It must be either provided in the detector environment "
                "or as model argument."
            )
    elif isinstance(wavelength, int | float):
        if wavelength <= 0:
            raise ValueError(
                "Parameter 'wavelength' must be strictly positive. "
                f"Got: {wavelength}"
            )
        selected_wavelength = Quantity(wavelength, unit="nm")
    elif (
        isinstance(wavelength, Sequence)
        and not isinstance(wavelength, str)
        and len(wavelength) == 2
    ):
        cut_on, cut_off = wavelength
        selected_wavelength = (
            Quantity(cut_on, unit="nm"),
            Quantity(cut_off, unit="nm"),
        )
        if not (0 < cut_on < cut_off):
            raise ValueError(
                "'wavelength' must be increasing and strictly positive. "
                f"Got: {selected_wavelength!r}"
            )
    else:
        # Without this branch 'selected_wavelength' would be left unbound.
        raise ValueError(
            "Parameter 'wavelength' must be a strictly positive number or a "
            f"sequence of two numbers (cut_on, cut_off). Got: {wavelength!r}"
        )

    # Create 'OpticalElement' from an input 'dict'
    optical_elements: Sequence["op.OpticalElement"] = [
        create_optical_item(dct, default_wavelength=selected_wavelength)
        for dct in optical_system
    ]

    # Depending on Type calculate for 2D or 3D photon
    if isinstance(selected_wavelength, Quantity):
        # 2D: a single wavelength
        if detector.photon.ndim != 2:
            raise ValueError(
                f"A 'detector.photon' 2D is expected. Got an '{detector.photon.ndim=}' "
            )

        # Get a Point Spread Function
        psf_hdu: fits.PrimaryHDU
        psf_hdu, _ = calc_psf(
            wavelengths=[selected_wavelength.to("m").value],
            fov_arcsec=fov_arcsec,
            pixel_scale=pixel_scale_with_unit,
            optical_elements=optical_elements,
            apply_jitter=apply_jitter,
            jitter_sigma=jitter_sigma,
        )
        psf_3d: np.ndarray = psf_hdu.data
        psf_2d: np.ndarray = psf_3d[0, :, :]  # single wavelength slice

        if extract_psf and detector.is_first_readout:
            optical_elements_attrs: dict[str, str | int] = {
                "num_optical_elements": len(optical_elements)
            }
            for idx, element in enumerate(optical_elements):
                optical_elements_attrs[f"element_{idx}"] = str(element)

            model_name: str = detector.current_running_model_name
            general_attrs = {
                "model": model_name,
                "wavelength": str(selected_wavelength),
                "fov": str(Quantity(fov_arcsec, unit="arcsec")),
                "pixel_scale": str(pixel_scale_with_unit),
                "apply_jitter": apply_jitter,
                "jitter_sigma": jitter_sigma,
            }
            psf_info = xr.DataArray(
                psf_2d,
                dims=["y", "x"],
                coords={
                    "wavelength": xr.DataArray(
                        selected_wavelength.value,
                        attrs={"unit": selected_wavelength.unit},
                    )
                },
                attrs=general_attrs | optical_elements_attrs,
            )
            detector.data[f"/photon_collection/optical_psf/{model_name}/psf"] = psf_info

        # Convolution
        new_array_2d: np.ndarray = apply_convolution(
            data=detector.photon.array,
            kernel=psf_2d,
        )
        detector.photon.array = new_array_2d

    else:
        # 3D: a (cut_on, cut_off) wavelength range
        if detector.photon.ndim != 3:
            raise ValueError(
                f"A 'detector.photon' 3D is expected. Got an '{detector.photon.ndim=}' "
            )

        min_wavelength, max_wavelength = selected_wavelength

        # Get current wavelengths (in nm)
        start_wavelength = Quantity(min_wavelength, unit="m")
        end_wavelength = Quantity(max_wavelength, unit="m")

        wavelengths_nm: Quantity = Quantity(
            detector.photon.array_3d["wavelength"], unit="nm"
        )

        tolerance = Quantity(1e-7, unit="m")
        in_range: np.ndarray = np.logical_and(
            wavelengths_nm >= (start_wavelength - tolerance),
            wavelengths_nm <= (end_wavelength + tolerance),
        )
        selected_wavelengths_nm: Quantity = wavelengths_nm[in_range]

        if selected_wavelengths_nm.size == 0:
            raise ValueError(
                f"The provided wavelength range ({min_wavelength:unicode}, {max_wavelength:unicode}) has "
                f"no overlap with the wavelengths from 'detector.photon.array_3d' "
                f"({wavelengths_nm[0]:unicode}, {wavelengths_nm[-1]:unicode})."
            )

        # Get a Point Spread Function (one slice per selected wavelength)
        psf_hdu_3d: fits.PrimaryHDU
        psf_hdu_3d, _ = calc_psf(
            wavelengths=selected_wavelengths_nm.to("m").value,
            fov_arcsec=fov_arcsec,
            pixel_scale=pixel_scale_with_unit,
            optical_elements=optical_elements,
            apply_jitter=apply_jitter,
            jitter_sigma=jitter_sigma,
        )
        psf_3d = psf_hdu_3d.data

        if extract_psf and detector.is_first_readout:
            optical_elements_attrs = {"num_optical_elements": len(optical_elements)}
            for idx, element in enumerate(optical_elements):
                optical_elements_attrs[f"element_{idx}"] = str(element)

            model_name = detector.current_running_model_name
            general_attrs = {
                "model": model_name,
                "wavelengths": f"From {min_wavelength} to {max_wavelength}",
                "fov": str(Quantity(fov_arcsec, unit="arcsec")),
                "pixel_scale": str(pixel_scale_with_unit),
                "apply_jitter": apply_jitter,
                "jitter_sigma": jitter_sigma,
            }
            psf_info = xr.DataArray(
                psf_3d,
                dims=["wavelength", "y", "x"],
                coords={
                    "wavelength": xr.DataArray(
                        selected_wavelengths_nm.value,
                        dims="wavelength",
                        attrs={"unit": selected_wavelengths_nm.unit},
                    )
                },
                attrs=general_attrs | optical_elements_attrs,
            )
            detector.data[f"/photon_collection/optical_psf/{model_name}/psf"] = psf_info

        # Convolution. Only the photon slices matching the selected wavelengths
        # are used so that the result agrees with the wavelength coordinates
        # built below (wavelength is the first dimension of 'array_3d').
        new_array_3d: np.ndarray = apply_convolution(
            data=detector.photon.array_3d.to_numpy()[in_range],
            kernel=psf_3d,
        )

        data_selected_wavelength = xr.DataArray(
            selected_wavelengths_nm.value,
            dims=["wavelength"],
            attrs={"wavelength": convert_unit(selected_wavelengths_nm.unit)},
        )

        array_3d = xr.DataArray(
            new_array_3d,
            dims=["wavelength", "y", "x"],
            coords={"wavelength": data_selected_wavelength},
        )
        detector.photon.array_3d = array_3d