import logging
import numpy as np
import pandas as pd
from re import sub
from pathlib import Path
from matplotlib import pyplot as plt, rc_context
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
from matplotlib.widgets import RadioButtons, SpanSelector, Slider
from matplotlib.ticker import NullLocator
from astropy.io import fits
from lime.io import load_frame, save_frame, LiMe_Error, check_file_dataframe
from lime.plotting.plots import Plotter, frame_mask_switch, save_close_fig_swicth, mplcursor_parser,\
determine_cube_images, load_spatial_mask, check_image_size, \
image_plot, spec_plot, spatial_mask_plot, _masks_plot, theme, line_band_plotter, spec_mask_plotter, \
line_band_scaler, spec_profile_plotter
from lime.tools import pd_get, unique_line_arr
from lime.transitions import label_decomposition, Line, check_continua_bands
_logger = logging.getLogger('LiMe')
def check_line_selection(spec, input_log, obj_bands, selected_by_default=True, **kwargs):
# Use the reference bands (by default the lines database) to compute the lines on the object
ref_params = {**kwargs, **{'automatic_grouping': False, 'fit_cfg':None,
'default_cfg_prefix': None, 'obj_cfg_prefix': None}}
ref_bands = spec.retrieve.lines_frame(**ref_params)
# Check if there is a physical bands file
file_bands = check_file_dataframe(input_log, verbose=False)
# Physical file has preference over ref bands and by default those are active
if file_bands is not None:
in_bands = file_bands
default_status = 1
# There is an input object bands and those are assumed as detected
elif obj_bands is not None:
in_bands = obj_bands
default_status = 1
# New bands are created and the user decides if detected or not
else:
in_bands = spec.retrieve.lines_frame(**kwargs)
default_status = 1 if selected_by_default else 0
# Extract the lines from each dataframe and only get the ones in common (without suffixes)
in_lines = in_bands.index.to_numpy()
ref_lines = ref_bands.index.to_numpy()
in_core = np.array([sub(r'_(b|m)$', '', line) for line in in_lines])
ref_core = np.array([sub(r'_(b|m)$', '', line) for line in ref_lines])
# Give priority to those in the input log
comb_lines, idx = np.unique(np.concatenate((in_core, ref_core)), return_index=True)
idx_in = idx[idx < in_core.size]
idx_ref = idx[idx >= in_core.size] - in_core.size
# Create empty log
labels_arr = np.concatenate((in_lines[idx_in], ref_lines[idx_ref]))
log = pd.DataFrame(index=labels_arr, columns=in_bands.columns)
# Fill the values
ref_columns = np.intersect1d(log.columns, ref_bands.columns)
log.loc[in_lines[idx_in], in_bands.columns] = in_bands.loc[in_lines[idx_in], in_bands.columns].to_numpy()
log.loc[ref_lines[idx_ref], ref_columns] = ref_bands.loc[ref_lines[idx_ref], ref_columns].to_numpy()
# Generate array with reference for reference
active_lines = np.zeros(comb_lines.size).astype(int)
active_lines[idx_in] = default_status
# Sort to restore the order lost with unique
if 'wavelength' in log.columns:
sorted_indexes = log['wavelength'].values.argsort()
else:
wave_arr = label_decomposition(log.index.to_numpy(), params_list=['wavelength'], verbose=False)
sorted_indexes = np.argsort(wave_arr[0])
# Use the sorted index to reorder the DataFrame
log = log.iloc[sorted_indexes]
active_lines = active_lines[sorted_indexes].astype(bool)
labels_arr = labels_arr[sorted_indexes]
# Set NaN entries in dataframe as None
if 'group_label' in log.columns:
idcs_nan = log.group_label.isnull()
log.loc[idcs_nan, 'group_label'] = 'none'
return log, labels_arr, active_lines
def load_redshift_table(file_address, column_name):
file_address = Path(file_address)
# Open the file
if file_address.is_file():
log = load_frame(file_address)
if not (column_name in log.columns):
_logger.info(f'No column "{column_name}" found in input dataframe, a new column will be added to the file')
# Load the file
else:
if file_address.parent.is_dir():
log = pd.DataFrame(columns=[column_name])
else:
raise LiMe_Error(f'The input log directory: {file_address.parent} \n does not exist')
return log
def save_redshift_table(object, redshift, file_address):
if redshift != 0:
filePath = Path(file_address)
if filePath.parent.is_dir():
# Create a new dataframe and save it
if not filePath.is_file():
df = pd.DataFrame(data=redshift, index=[object], columns=['redshift'])
# Replace or append to dataframe
else:
df = pd.read_csv(filePath, delim_whitespace=True, header=0, index_col=0)
df.loc[object, 'redshift'] = redshift
# Save back
with open(filePath, 'wb') as output_file:
string_DF = df.to_string()
output_file.write(string_DF.encode('UTF-8'))
else:
_logger.warning(f'Output redshift table folder does not exist at {file_address}')
return
def circle_band_label(current_label):
match current_label[-2:]:
case '_b':
return f'{current_label[:-2]}_m'
case '_m':
return current_label[:-2]
case _:
return f'{current_label}_b'
def save_or_clear_log(log, log_address, active_lines, log_parameters='all'):
if log_parameters is None:
log_parameters = ['wavelength', 'wave_vac', 'w1', 'w2', 'w3', 'w4', 'w5', 'w6', 'latex_label',
'units_wave', 'particle', 'transition', 'rel_int']
if np.sum(active_lines) == 0:
if log_address.is_file():
log_address.unlink()
else:
if log_address is not None:
save_frame(log_address, log.loc[active_lines], parameters=log_parameters)
else:
_logger.warning(r"Not output redshift log provided, the selection won't be stored")
return
class BandsInspection:
def __init__(self):
self.fig = None
self.ax_list = None
self.ax = None
self.y_scale = None
self.show_continua = None
self.line_list = None
self.active_lines = None
self.fname = None
self.line = None
self.log = None
self.mask = None
self.wave_plot = None
self.flux_plot = None
self.z_corr = None
self.idcs_mask = None
self.color_bg = {True: theme.colors['line_selected'],
False: theme.colors['line_removed']}
self.out_params = ["wavelength", "wave_vac", "w1", "w2", "w3", "w4", "w5", "w6",
"units_wave", "particle", "transition", "rel_int"]
return
[docs]
def bands(self, fname, bands=None, default_status=True, show_continua=False, y_scale='auto', n_cols=6, n_rows=None,
col_row_scale=(1, 0.5), fig_cfg=None, in_fig=None, maximize=False, **kwargs):
"""
Launch an interactive line-bands editor and save selections to a lines frame file.
This tool opens a plot grid, one per spectral line, allowing you to
**inspect** each region and **adjust the band central edges (w3–w4)** used for
measurements. Edits are written to ``fname`` (a bands table on disk). If you
run the editor again on the same file, **existing user selections are preserved**.
A reference lines frame with the candidate lines can be provided via ``bands`` the default LiMe bands database is used.
The function accepts the arguments from the :meth:`lime.Spectrum.retrieve.lines_frame` to adjust the default database
values.
If `show_continua=True`` the user can also adjust the continua bands region.
**Left-click and drag** within a subplot to adjust the wavelength band limits interactively.
A **middle-click** on a subplot to cycle through the line group types in the output bands file.
The available suffixes are: blended (``_b``), merged (``_m``), and single (no suffix). The line label in the
plot title updates accordingly.
A **right-click** adds/remove a line from the selection (exluded lines have a red background)
Parameters
----------
fname : str or pathlib.Path
Output file path where the edited bands table will be saved.
bands : pandas.DataFrame, str, or pathlib.Path, optional
Reference bands table (or path) providing initial band limits for each line.
If ``None``, the default LiMe bands database is used. See
:ref:`bands documentation <line-bands-doc>`.
default_status : bool, optional
Initial selection status to apply **only** when a line has no existing entry
in ``fname`` (i.e., on first creation). Default is ``True``.
show_continua : bool, optional
If ``True``, draw the continuum side bands in each panel. Default ``False``.
y_scale : {"auto", "linear", "log"}, optional
Flux scale for all panels. ``"auto"`` chooses a sensible scale per panel;
otherwise force Matplotlib’s ``"linear"`` or ``"log"``. Default ``"auto"``.
n_cols : int, optional
Number of columns in the grid. Default ``6``.
n_rows : int, optional
Number of rows in the grid. If ``None``, it is inferred from the number of lines.
col_row_scale : tuple of (float, float), optional
Multiplicative factors for panel width and height (rough layout scaling).
Default ``(1, 0.5)``.
fig_cfg : dict, optional
Matplotlib figure configuration (e.g., size, DPI). See `matplotlib.RcParams
<https://matplotlib.org/stable/api/matplotlib_configuration_api.html#matplotlib.RcParams>`_.
in_fig : matplotlib.figure.Figure, optional
Existing figure to plot into. If ``None``, a new figure is created.
maximize : bool, optional
If ``True``, maximize the window after rendering. Default ``False``.
**kwargs
Additional keyword arguments passed to :meth:`lime.Spectrum.retrieve.lines_frame`
Notes
-----
- Band edits (w1–w6) are written to ``fname`` immediately when saved/closed.
- If no changes are performed the output ``fname`` will not be created.
- Existing entries in ``fname`` are **not overwritten**; new lines inherit
their initial state from ``default_status``.
- **Left-click and drag** within a subplot to adjust the wavelength band limits.
- **Right-click** on a line subplot to include or exclude the line from the output bands file.
Excluded lines are displayed with a red background.
- **Middle-click** on a subplot to cycle through the line group types in the output bands file.
The available suffixes are: blended (``_b``), merged (``_m``), and single (no suffix).
The line label in the plot title updates accordingly.
Examples
--------
Start from the default database and save to a new file:
>>> spec.check.bands("my_bands.xlsx")
Initialize from an existing bands table and show the continua bands for the selection:
>>> spec.check.bands("session_bands.xlsx", bands="ref_bands.xlsx", show_bands=True)
Adjusting the initial ``bands`` generation with the arguments from the with the
:meth:`lime.Spectrum.retrieve.lines_frame` function
>>> gp_spec.check.bands(lineBandsFile, band_vsigma=100, n_sigma=4, instrumental_correction=True,
>>> map_band_vsigma={'H1_4861A': 200, 'H1_6563A': 200, 'O3_4959A': 250, 'O3_5007A': 250},
>>> fit_cfg=obs_cfg, ref_bands=osiris_gp_df_path, show_bands=True, maximize=True)
"""
# Declare the function attributes
self.y_scale = y_scale
self.show_continua = show_continua
# Check the address of the output line frame
if isinstance(fname, (str, Path)):
self.fname = Path(fname)
if not self.fname.parent.is_dir():
raise LiMe_Error(f'Input bands file directory does not exist ({self.fname.parent.as_posix()})')
# Establish the list of line and their status
self.log, self.line_list, self.active_lines = check_line_selection(self._spec, self.fname, bands,
default_status, **kwargs)
# Store spectrum
self.wave_plot, self.flux_plot, _, self.z_corr, self.idcs_mask = frame_mask_switch(self._spec, True)
# Check there are lines in selection
n_lines = self.log.index.size
if n_lines > 0:
# Compute the rows and columns number
if n_lines > n_cols:
n_rows = n_rows or int(np.ceil(n_lines / n_cols))
else:
n_cols, n_rows = n_lines, 1
n_grid = n_cols * n_rows
# User configuration overwrites default configuration
size_conf = {'figure.figsize': (n_cols * col_row_scale[0], n_rows * col_row_scale[1])}
size_conf = size_conf if fig_cfg is None else {**size_conf, **fig_cfg}
plt_cfg = theme.fig_defaults(size_conf, fig_type='grid')
# Launch the interative figure
with rc_context(plt_cfg):
# Figure structure
self.fig = plt.figure() if in_fig is None else in_fig
# grid_spec = self.fig.add_gridspec(2, 1, height_ratios=[1, 0.1])
grid_spec = self.fig.add_gridspec(1, 1)
gs_lines = grid_spec[0].subgridspec(n_rows, n_cols, hspace=0.5)
self.ax_list = gs_lines.subplots().flatten() if n_lines > 1 else [gs_lines.subplots()]
# Fill the plot axes
span_selector_dict = {}
for i in range(n_grid):
if i < n_lines:
self.line = self.line_list[i]
self.plot_line_BI(self.ax_list[i], self.line)
span_selector_dict[f'spanner_{i}'] = SpanSelector(self.ax_list[i], self.on_select_BI, button=1,
direction='horizontal', useblit=True,
props=dict(alpha=0.5, facecolor='tab:blue'))
else:
# Clear not filled axes
self.fig.delaxes(self.ax_list[i])
# Connecting the figure to the interactive widgets
self.fig.canvas.mpl_connect('button_press_event', self.on_click_BI)
self.fig.canvas.mpl_connect('axes_enter_event', self.on_enter_axes_BI)
# Show the image
save_close_fig_swicth(None, True, self.fig, maximise=maximize,
plot_check=True if in_fig is None else False)
else:
_logger.warning(f'No lines found in the lines mask for the object wavelentgh range')
return
def plot_line_BI(self, ax, line, scale_dict=theme.plt):
# Establish the limits for the line spectrum plot
mask = self.log.loc[line, 'w1':'w6'].to_numpy() * self.z_corr
idcs_band = np.searchsorted(self.wave_plot, mask)
# Review the bands edges
idcs_band = check_continua_bands(idcs_band, self.wave_plot, reset_w2_w5=True)
# Just the center region is adjusted
if self.show_continua:
idxL = idcs_band[2] - 10 if idcs_band[2] - 10 > 0 else 0
idxH = idcs_band[3] + 10 if idcs_band[3] + 10 < self.wave_plot.size - 1 else self.wave_plot.size - 1
# Center + continua
else:
idxL = idcs_band[0] - 5 if idcs_band[0] - 5 > 0 else 0
idxH = idcs_band[5] + 5 if idcs_band[5] + 5 < self.wave_plot.size - 1 else self.wave_plot.size - 1
# Plot the spectrum
ax.step(self.wave_plot[idxL:idxH]/self.z_corr, self.flux_plot[idxL:idxH]*self.z_corr, where='mid',
color=theme.colors['fg'], linewidth=scale_dict['spectrum_width'])
# Continuum bands
line_band_plotter(ax, self.wave_plot, self.flux_plot, self.z_corr, idcs_band, line, theme.colors,
show_adjacent=self.show_continua)
# Plot the masked pixels
spec_mask_plotter(ax, self.idcs_mask[idxL:idxH], self.wave_plot[idxL:idxH], self.flux_plot[idxL:idxH],
self.z_corr, self.log, line, theme.colors)
# Plot line location
wave_line = pd_get(self.log, line, 'wavelength')
if wave_line is not None:
ax.axvline(wave_line, linestyle='--', color='grey', linewidth=0.5)
# Background for selective line for selected lines
if self.active_lines[self.line_list == line][0]:
ax.set_facecolor(theme.colors['line_selected'])
else:
ax.set_facecolor(theme.colors['line_removed'])
# Scale the y axis
line_band_scaler(ax, self.flux_plot[idxL:idxH] * self.z_corr, 'auto')
# Formatting the figure
ax.set_title(line, pad=3)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.yaxis.set_minor_locator(NullLocator())
ax.get_xlim() # TODO without this one there is no plot
return
def on_select_BI(self, w_low, w_high):
# Check we are not just clicking on the plot
if w_low != w_high:
# Just the central bands
if self.show_continua is False:
self.log.at[self.line, 'w3'] = w_low
self.log.at[self.line, 'w4'] = w_high
# Move the other bands to avoid issues
idx_low, idx_high = np.searchsorted(self.wave_plot, (w_low, w_high))
self.log.at[self.line, 'w1'] = self.wave_plot[np.max((0, idx_low - 5))]
self.log.at[self.line, 'w2'] = self.wave_plot[np.max((0, idx_low - 1))]
self.log.at[self.line, 'w5'] = self.wave_plot[np.min((idx_high + 1, self.wave_plot.size -1))]
self.log.at[self.line, 'w6'] = self.wave_plot[np.min((idx_high + 5, self.wave_plot.size -1))]
# Central and adjacent bands
else:
# Correcting line band
if w_low > self.log.at[self.line, 'w2'] and w_high <self.log.at[self.line, 'w5']:
self.log.at[self.line, 'w3'] = w_low
self.log.at[self.line, 'w4'] = w_high
# Correcting blue band
elif w_low < self.log.at[self.line, 'w3'] and w_high < self.log.at[self.line, 'w3']:
self.log.at[self.line, 'w1'] = w_low
self.log.at[self.line, 'w2'] = w_high
# Correcting Red
elif w_low > self.log.at[self.line, 'w4'] and w_high > self.log.at[self.line, 'w4']:
self.log.at[self.line, 'w5'] = w_low
self.log.at[self.line, 'w6'] = w_high
# Removing line
elif w_low < self.log.at[self.line, 'w1'] and w_high > self.log.at[self.line, 'w6']:
print(f'\n-- The line {self.line} mask has been removed')
# Weird case
else:
_logger.info(f'Unsuccessful line selection: {self.line}: w_low: {w_low}, w_high: {w_high}')
# Save the log to the file
save_or_clear_log(self.log, self.fname, self.active_lines, self.out_params)
# Redraw the line measurement
self.ax.clear()
self.plot_line_BI(self.ax, self.line)
self.fig.canvas.draw()
return
def on_enter_axes_BI(self, event):
# Assign current line and axis
self.ax = event.inaxes
title = self.ax.get_title()
if title != '':
self.line = title
def on_click_BI(self, event):
if event.button in (2, 3):
# Update the line label
if event.button == 2:
idx = self.line_list == self.line
new_name = circle_band_label(self.line)
self.log.rename(index={self.line: new_name}, inplace=True)
self.line = new_name
self.line_list[idx] = new_name
# Remove group label and latex label since it cannot be restored
for entry in ['group_label', 'latex_label']:
if entry in self.log.columns:
self.log.loc[self.line, entry] = 'none'
# Update the line active status
if event.button == 3:
idx = self.line == self.line_list
self.active_lines[idx] = np.invert(self.active_lines[idx])
# Save the log to the file
save_or_clear_log(self.log, self.fname, self.active_lines, self.out_params)
# Plot the line selection with the new Background
self.ax.clear()
self.plot_line_BI(self.ax, self.line)
self.fig.canvas.draw()
return
class RedshiftInspection:
def __init__(self):
# Plot Attributes
self._fig = None
self._ax = None
self._AXES_CONF = None
self._spec_label = None
self._legend_handle = None
# Input data
self._obj_idcs = None
self._column_log = None
self._log_address = None
self._output_idcs = None
self._latex_array = None
self._waves_array = None
# User pointing
self._lineSelection = None
self._user_point = None
self._none_value = None
self._unknown_value = None
self._sample_object = None
def redshift(self, obj_idcs, reference_lines, output_file_log=None, output_idcs=None, redshift_column='redshift',
initial_z=None, none_value=np.nan, unknown_value=0.0, legend_handle='levels', maximize=False, title=None,
output_address=None, n_pixels=10, fig_cfg={}, ax_cfg={}, in_fig=None, **kwargs):
# Check if input tuple
if isinstance(obj_idcs, tuple):
obj_idcs = pd.MultiIndex.from_tuples([obj_idcs], names=self._sample.index.names)
# Assign the attributes
self._obj_idcs = obj_idcs if isinstance(obj_idcs, pd.MultiIndex) else self._sample.loc[obj_idcs].index
self._column_log = redshift_column
self._none_value = none_value
self._unknown_value = unknown_value
self._spec_label = "" if title is None else title
self._legend_handle = legend_handle
self._user_point = None
# Parameters for the load function
self._load_params = {**self._sample.load_params, **kwargs}
self._load_params['redshift'] = 0
# Output Log params
self._log_address = output_file_log
# Only save new redshift in input idx if None provided
if output_idcs is None:
self._output_idcs = self._obj_idcs
else:
self._output_idcs = output_idcs if isinstance(output_idcs, pd.MultiIndex) else self._sample.loc[output_idcs].index
# Check the redshift column exists
if self._column_log not in self._sample.frame.columns:
raise LiMe_Error(f'Redshift column "{redshift_column}" does not exist in the current sample log.')
# Use provided redshift value
if initial_z is not None:
redshift_pred = initial_z
else:
redshift_pred = self._sample.loc[self._obj_idcs, self._column_log].to_numpy()
redshift_pred = None if np.all(pd.isnull(redshift_pred)) else np.nanmean(redshift_pred)
# Create initial entry
self._compute_redshift(redshift_output=redshift_pred)
# Get the lines transitions and latex labels
reference_bands_df = check_file_dataframe(reference_lines)
if reference_bands_df is None:
raise LiMe_Error(f'Reference line log could not be read ({reference_lines})')
else:
if isinstance(reference_bands_df, pd.DataFrame):
reference_lines = reference_bands_df.index.to_numpy()
# Sort by wavelength
_waves_array, _latex_array = label_decomposition(reference_lines, params_list=('wavelength', 'latex_label'))
idcs_sorted = np.argsort(_waves_array)
self._waves_array, self._latex_array = _waves_array[idcs_sorted], _latex_array[idcs_sorted]
# Set the plot format where the user's overwrites the default
size_conf = {'figure.figsize': (10, 6), 'axes.labelsize': 12, 'xtick.labelsize': 10, 'ytick.labelsize': 10}
size_conf = size_conf if fig_cfg is None else {**size_conf, **fig_cfg}
PLT_CONF = theme.fig_defaults(size_conf)
self._AXES_CONF = theme.ax_defaults(ax_cfg, self._sample, fig_type=None)
# Create and fill the figure
with rc_context(PLT_CONF):
# Generate the figure object and figures
self._fig = plt.figure() if in_fig is None else in_fig
gs = GridSpec(nrows=1, ncols=2, figure=self._fig, width_ratios=[2, 0.5], height_ratios=[1])
self._ax = self._fig.add_subplot(gs[0])
self._ax.set(**self._AXES_CONF)
# Create the RadioButtons widget for the lines
buttoms_ax = self._fig.add_subplot(gs[1])
labels_buttons = [r'$None$'] + list(self._latex_array) + [r'$Unknown$']
radio_props = {'s': [10] * len(labels_buttons)}
label_props = {'fontsize': [6] * len(labels_buttons)}
radio = RadioButtons(buttoms_ax, labels_buttons, radio_props=radio_props, label_props=label_props)
# Plot the spectrum
self._launch_plots_ZI()
# Connect the widgets
radio.on_clicked(self._button_ZI)
self._fig.canvas.mpl_connect('button_press_event', self._on_click_ZI)
# Plot on screen unless an output address is provided
# save_close_fig_swicth(output_address, 'tight', self._fig, maximise=maximize)
save_close_fig_swicth(None, None, self._fig, maximise=maximize,
plot_check=True if in_fig is None else False)
return
def _launch_plots_ZI(self):
# Get redshift from log
redshift_pred = self._sample.loc[self._obj_idcs, self._column_log].to_numpy()
redshift_pred = None if np.all(pd.isnull(redshift_pred)) else np.nanmean(redshift_pred)
# Store the figure limits
xlim, ylim = self._ax.get_xlim(), self._ax.get_ylim()
# Redraw the figure
self._ax.clear()
self._plot_spectrum_ZI(self._ax)
self._plot_line_labels_ZI(self._ax, self._user_point, redshift_pred)
self._ax.legend(loc=4)
title = f'{self._spec_label} z calculation'
if redshift_pred not in [None, self._none_value, self._unknown_value]:
title += f', redshift = {redshift_pred:0.3f}'
self._ax.set_title(title)
# Reset axis format
if (xlim[0] != 0) and (xlim[0] != 1): # First time
self._ax.set_xlim(xlim)
self._ax.set_ylim(ylim)
self._ax.set(**self._AXES_CONF)
self._fig.canvas.draw()
return
def _plot_spectrum_ZI(self, ax):
# Loop through the objects
for i, obj_idx in enumerate(self._obj_idcs):
# Load the spectrum with a zero redshift
spec = self._sample.load_function(self._sample.frame, obj_idx, self._sample.file_address,
instrument=self._sample.instrument, **self._load_params)
# Plot on the observed frame with reshift = 0
wave_plot, flux_plot, err_plot, z_corr, idcs_mask = frame_mask_switch(spec, True)
# Plot the spectrum
ax.step(wave_plot/z_corr, flux_plot*z_corr, label=self._label_generator(obj_idx), where='mid',
linewidth=theme.plt['spectrum_width'])
# Plot the masked pixels
_masks_plot(ax, None, wave_plot, flux_plot, z_corr, spec.frame, idcs_mask, color_dict=theme.colors)
return
def _plot_line_labels_ZI(self, ax, click_coord, redshift_pred):
if (redshift_pred != 0) and (not pd.isnull(redshift_pred)):
wave_min, wave_max = None, None
for obj_idx in self._obj_idcs:
# Load the spectrum
spec = self._sample.load_function(self._sample.frame, obj_idx, self._sample.file_address,
instrument=self._sample.instrument, **self._load_params)
wavelength = spec.wave.data
wavelength = wavelength[~np.isnan(wavelength)]
if wave_min is None:
wave_min = wavelength[0]
else:
wave_min = wavelength[0] if wavelength[0] < wave_min else wave_min
if wave_max is None:
wave_max = wavelength[-1]
else:
wave_max = wavelength[-1] if wavelength[-1] > wave_max else wave_max
# Check the lines which fit in the plot region
idcs_in_range = np.logical_and(self._waves_array * (1 + redshift_pred) >= wave_min,
self._waves_array * (1 + redshift_pred) <= wave_max)
# Plot lines in region
linesRange = self._waves_array[idcs_in_range]
latexRange = self._latex_array[idcs_in_range]
for i, lineWave in enumerate(linesRange):
if latexRange[i] == self._lineSelection:
color_line = 'tab:red'
else:
color_line = theme.colors['fg']
ax.axvline(x=lineWave * (1 + redshift_pred), color=color_line, linestyle='--', linewidth=0.5)
ax.annotate(latexRange[i], xy=(lineWave * (1 + redshift_pred), 0.85),
horizontalalignment="center",
rotation=90,
backgroundcolor='w',
size=6,
xycoords='data', xytext=(lineWave * (1 + redshift_pred), 0.85), textcoords=("data",
"axes fraction"),
bbox=dict(
facecolor=theme.colors['bg'], # Background color
edgecolor='none', # Border color
))
return
def _compute_redshift(self, redshift_output=None):
# Routine not to overwrite first measurement
if redshift_output is None:
# First time case: Both input but be provided
if self._lineSelection is not None:
# Default case nothing is selected:
if self._lineSelection == r'$None$':
_redshift_pred = self._none_value
elif self._lineSelection == r'$Unknown$':
_redshift_pred = self._unknown_value
# Wavelength selected
else:
if self._user_point is not None:
idx_line = self._latex_array == self._lineSelection
ref_wave = self._waves_array[idx_line][0]
_redshift_pred = self._user_point[0] / ref_wave - 1
# Special cases None == NaN, Unknown == 0
else:
_redshift_pred = self._none_value
else:
_redshift_pred = self._none_value
else:
_redshift_pred = redshift_output
# Store the new redshift
self._sample.loc[self._output_idcs, self._column_log] = _redshift_pred
# Save to file if provided
if self._log_address is not None:
save_frame(self._log_address, self._sample.frame)
return
def _button_ZI(self, line_selection):
# Button selection
self._lineSelection = line_selection
# Compute the redshift
self._compute_redshift()
# Replot the figure
self._launch_plots_ZI()
return
def _on_click_ZI(self, event, tolerance=3):
if event.button == 2:
self._user_point = (event.xdata, 0.5)
# Compute the redshift
self._compute_redshift()
# Replot the figure
self._launch_plots_ZI()
return
def _label_generator(self, idx_sample):
if self._legend_handle == 'levels':
spec_label =", ".join(map(str, idx_sample))
else:
if self._legend_handle in self._sample.index.names:
idx_item = list(self._sample.index.names).index(self._legend_handle)
spec_label = idx_sample[idx_item]
elif self._legend_handle in self._sample.frame.columns:
spec_label = self._sample.frame.loc[idx_sample, self._legend_handle]
else:
raise LiMe_Error(f'The input handle "{self._legend_handle}" is not found on the sample log columns')
return spec_label
class CubeInspection:
def __init__(self):
# Data attributes
self.grid_mesh = None
self.bg_image = None
self.fg_image = None
self.fg_levels = None
self.hdul_linelog = None
self.ext_log = None
self.spaxel_button = None
self.add_remove_button = None
self.spec = None
# Mask correction attributes
self.mask_file = None
self.mask_ext = None
self.masks_dict = {}
self.mask_color = None
self.mask_array = None
# Plot attributes
self.in_ax = None
self.axes_conf = {}
self.axlim_dict = {}
self.color_norm = None
self.mask_color_i = None
self.key_coords = None
self.marker = None
self.rest_frame = None
self.log_scale = None
self.restore_zoom = False
self.maintain_y_zoom = False
return
[docs]
def cube(self, line_bg, bands=None, line_fg=None, min_pctl_bg=60, cont_pctls_fg=(90, 95, 99), bg_cmap='gray',
fg_cmap='viridis', bg_norm=None, fg_norm=None, masks_file=None, masks_cmap='viridis_r', masks_alpha=0.2,
rest_frame=False, log_scale=False, fig_cfg=None, ax_cfg_image=None, ax_cfg_spec=None, in_fig=None,
fname=None, ext_frame_suffix='_LINELOG', maintain_y_zoom=True, wcs=None, spaxel_selection_button=1,
add_remove_button=3, maximize=False):
"""
Open an interactive cube viewer: image map (left) + spaxel spectrum (right).
The left panel shows a band-summed flux map for ``line_bg`` (bands from ``bands`` or
the default database). Optionally overlay **foreground contours** from ``line_fg``.
Clicking on a spaxel updates the right panel with that spaxel’s spectrum. If
a ``lines_file`` FITS log is provided, fitted profiles for the selected spaxel
are also displayed.
If a mask file is supplied (``masks_file``), a radio selector appears to toggle
which binary mask is active/visible. You can add/remove the currently selected
spaxel to/from the active mask using a configurable mouse button.
Parameters
----------
line_bg : str
Line label for the background image (band-summed flux). See :ref:`bands documentation <line-bands-doc>`.
bands : pandas.DataFrame, str, or pathlib.Path, optional
Bands table (or path). If ``None``, the default LiMe bands database is used.
line_fg : str, optional
Line label for foreground **contours**. If provided, contours are computed
from that line’s band-summed flux using ``cont_pctls_fg``.
min_pctl_bg : float, optional
Minimum percentile of the background band-summed flux used to set the lower
display limit when ``bg_norm`` is not supplied. Default is ``60``.
cont_pctls_fg : tuple of float, optional
Sorted percentiles for foreground contours. Default is ``(90, 95, 99)``.
bg_cmap : str, optional
Colormap name for the background image. Default is ``"gray"``.
fg_cmap : str, optional
Colormap name for the foreground contours. Default is ``"viridis"``.
bg_norm : matplotlib.colors.Normalize, optional
Normalization for the background image. If ``None``, a symmetric-log style
normalization is used (sensible defaults for wide dynamic ranges).
fg_norm : matplotlib.colors.Normalize, optional
Normalization for the foreground contours. If ``None``, a logarithmic
normalization is used.
masks_file : str or pathlib.Path, optional
Path to a FITS file with binary spatial masks to overlay and edit.
masks_cmap : str, optional
Colormap used to render masks. Default is ``"viridis_r"``.
masks_alpha : float, optional
Alpha transparency for mask overlays (0–1). Default is ``0.2``.
rest_frame : bool, optional
If ``True``, show the spectrum panel in rest-frame wavelengths. Default ``False``.
log_scale : bool, optional
If ``True``, plot the spectrum panel with a logarithmic flux scale. Default ``False``.
in_fig : matplotlib.figure.Figure, optional
Existing figure to plot into. If ``None``, a new figure is created.
fig_cfg : dict, optional
Matplotlib figure configuration (e.g., size, DPI). See `matplotlib.RcParams
<https://matplotlib.org/stable/api/matplotlib_configuration_api.html#matplotlib.RcParams>`_.
ax_cfg_image : dict, optional
Axes label/title overrides for the **image** panel; keys may include
``"xlabel"``, ``"ylabel"``, and ``"title"``.
ax_cfg_spec : dict, optional
Axes label/title overrides for the **spectrum** panel; keys may include
``"xlabel"``, ``"ylabel"``, and ``"title"``.
fname : str or pathlib.Path, optional
Path to a FITS lines-log file. If provided, fitted profiles for the selected
spaxel are displayed when available. Each spaxel page must be named
``"{j}-{i}{ext_frame_suffix}"`` (e.g., ``"25-30_LINELOG"``).
ext_frame_suffix : str, optional
Suffix used to match spaxel pages inside ``lines_file``. Default ``"_LINELOG"``.
maintain_y_zoom : bool, optional
If ``True`` (default), preserve the current y-axis zoom on the spectrum panel
when selecting different spaxels; if ``False``, autoscale on each selection.
wcs : astropy.wcs.WCS, optional
WCS to use for the image panel. If ``None``, the cube’s WCS is used when available.
See `Astropy WCS <https://docs.astropy.org/en/stable/wcs/index.html>`_.
spaxel_selection_button : int, optional
Mouse button used to **select/preview** a spaxel on the image panel.
Default is ``1`` (LEFT). Matplotlib button mapping: LEFT=1, MIDDLE=2, RIGHT=3,
BACK=8, FORWARD=9.
add_remove_button : int, optional
Mouse button used to **add/remove** the selected spaxel to/from the active mask.
Default is ``3`` (RIGHT). Matplotlib button mapping: LEFT=1, MIDDLE=2, RIGHT=3,
BACK=8, FORWARD=9.
maximize : bool, optional
If ``True``, maximize the window after rendering. Default ``False``.
Notes
-----
- **Interactivity:**
- Click the image panel with ``spaxel_selection_button`` to update the spectrum panel.
- Click the image panel with ``add_remove_button`` to toggle the selected spaxel in the active mask (when ``masks_file`` is provided).
- A round button selection changes the active mask.
- **Fitted profiles:** If :mod:`mplcursors` is installed and a fitted profile is
displayed, left-click shows fit parameters and right-click removes the tooltip.
- **Dynamic range tip:** With large flux dynamic ranges, percentile-based contour
levels may appear non-linear in visual spacing.
- **WCS handling:** If a valid WCS is provided (or available from the cube),
the image panel uses WCS projection and labeled axes.
Examples
--------
Basic interactive viewer with contours:
>>> cube.plot.cube("O3_5007A", line_fg="H1_6563A")
Use a masks file and RIGHT-click to add/remove spaxels from the active mask:
>>> cube.plot.cube("H1_6563A", masks_file="halpha_masks.fits", add_remove_button=3)
Show fitted profiles from a measurements log:
>>> cube.plot.cube("O3_5007A", fname="linelog.fits", ext_frame_suffix="_LINELOG")
"""
self.ext_log = ext_frame_suffix
self.mask_file = masks_file
self.spaxel_button = spaxel_selection_button
self.add_remove_button = add_remove_button
self.maintain_y_zoom = maintain_y_zoom
self.bg_color, self.fg_color = bg_cmap, fg_cmap
self.mask_color, self.mask_alpha = masks_cmap, masks_alpha
self.rest_frame, self.log_scale = rest_frame, log_scale
# Prepare the background image data
line_bg, self.bg_image, self.bg_levels, self.bg_scale = determine_cube_images(self._cube, line_bg, bands,
min_pctl_bg, bg_norm,
contours_check=False)
# Prepare the foreground image data
line_fg, self.fg_image, self.fg_levels, self.fg_scale = determine_cube_images(self._cube, line_fg, bands,
cont_pctls_fg, fg_norm,
contours_check=True)
# Mesh for the contours
self.fg_mesh = None if line_fg is None else np.meshgrid(np.arange(0, self.fg_image.shape[1]),
np.arange(0, self.fg_image.shape[0]))
# Load the masks
self.masks_dict = load_spatial_mask(self.mask_file)
self.mask_ext = list(self.masks_dict.keys())[0] if len(self.masks_dict) > 0 else self.ext_log
# Check that the images have the same size
check_image_size(self.bg_image, self.fg_image, self.masks_dict)
# Use the input wcs or use the parent one
wcs = self._cube.wcs if wcs is None else wcs
slices = None if wcs is None else ('x', 'y', 1) if wcs.naxis == 3 else ('x', 'y')
# Use central voxel as initial coordinate
self.key_coords = int(self._cube.flux.shape[1]/2), int(self._cube.flux.shape[2]/2)
# Load the complete fits lines log if input
if fname is not None:
if Path(fname).is_file():
self.hdul_linelog = fits.open(fname, lazy_load_hdus=False)
else:
_logger.info(f'The lines log at {fname} was not found.')
# Get figure configuration
fig_conf = theme.fig_defaults(fig_cfg, fig_type='cube_interactive')
self.axes_conf = {'image': theme.ax_defaults(ax_cfg_image, self._cube, fig_type='cube', line_bg=line_bg,
line_fg=line_fg, masks_dict=self.masks_dict, wcs=wcs),
'spectrum': theme.ax_defaults(ax_cfg_spec, self._cube, g_type='default', line_bg=line_bg,
line_fg=line_fg, masks_dict=self.masks_dict, wcs=wcs)}
grid_params = dict(nrows=1, ncols=2, width_ratios=[1, 2], height_ratios=[1])
sub_grid_params = dict(nrows=2, ncols=1, height_ratios=[0.7, 0.3])
# Create the figure
with rc_context(fig_conf):
# Figure structure
self.fig, self.ax = plt.figure() if in_fig is None else in_fig, [None, None, None]
gs = GridSpec(figure=self.fig, **grid_params)
sub_gs = gs if len(self.masks_dict) == 0 else GridSpecFromSubplotSpec(subplot_spec=gs[0], **sub_grid_params)
# Image and spectrum axes
self.ax[0] = self.fig.add_subplot(sub_gs[0]) if wcs is None else self.fig.add_subplot(sub_gs[0],
projection=wcs,
slices=slices)
self.ax[1] = self.fig.add_subplot(gs[1])
self.ax[2] = None if len(self.masks_dict) == 0 else self.fig.add_subplot(sub_gs[1])
# Buttons axis if provided
if self.ax[2] is not None:
radio = RadioButtons(self.ax[2],
labels=list(self.masks_dict.keys()),
radio_props={'s': [10] * len(self.masks_dict)},
label_props={'fontsize': [5] * len(self.masks_dict)})
radio.on_clicked(self.mask_selection)
# Plot the data
self.data_plots()
# Connect to the toolbar
self.toolbar = plt.get_current_fig_manager().toolbar
# Connect the widgets
self.fig.canvas.mpl_connect('axes_enter_event', self.on_enter_axes)
self.fig.canvas.mpl_connect('button_press_event', self.on_click)
self.fig.canvas.mpl_connect('button_release_event', self.click_zoom)
# Display the figure
save_close_fig_swicth(maximise=maximize, bbox_inches='tight', plot_check=True if in_fig is None else False)
# Close the lines log if it has been opened
if isinstance(self.hdul_linelog, fits.hdu.HDUList):
self.hdul_linelog.close()
return
def data_plots(self, show_profiles=True):
# Delete previous marker
if self.marker is not None:
self.marker.remove()
self.marker = None
# Background image
self.im, _, self.marker = image_plot(self.ax[0], self.bg_image, self.fg_image, self.fg_levels, self.fg_mesh,
self.bg_scale, self.fg_scale, self.bg_color, self.fg_color, self.key_coords)
# Spatial masks
spatial_mask_plot(self.ax[0], self.masks_dict, self.mask_color, self.mask_alpha, self._cube.units_flux,
mask_list=[self.mask_ext])
self.ax[0].update(self.axes_conf['image'])
# Voxel spectrum
spec = self.get_spaxel_spec()
wave_plot, flux_plot, err_plot, z_corr, idcs_mask = frame_mask_switch(spec, self.rest_frame)
# Plot the spectrum
self.ax[1].step(wave_plot / z_corr, flux_plot * z_corr, where='mid', color=theme.colors['fg'],
linewidth=theme.plt['spectrum_width'])
# Plot the fittings
if show_profiles and spec.frame.size > 0:
mplcursor_list = []
for line_label in unique_line_arr(spec.frame):
line = Line.from_transition(line_label, data_frame=spec.frame)
mplcursor_list += spec_profile_plotter(self.ax[1], spec, line, z_corr)
# Pop-ups
mplcursor_parser(mplcursor_list, spec)
# Y scale
if self.log_scale:
self.ax[1].set_yscale('log')
# Update the axis
self.axes_conf['spectrum']['title'] = f'Spaxel {self.key_coords[0]} - {self.key_coords[1]}'
self.ax[1].update(self.axes_conf['spectrum'])
return
def on_click(self, event, new_voxel_button=3):
if self.in_ax == self.ax[0]:
# Save axes zoom
self.save_zoom()
if event.button == self.spaxel_button:
# Save clicked coordinates for next plot
self.key_coords = np.rint(event.ydata).astype(int), np.rint(event.xdata).astype(int)
# Replot the figure
self.im.remove()
self.ax[1].clear()
self.data_plots()
self.reset_zoom()
self.fig.canvas.draw()
# if event.dblclick:
if event.button == self.add_remove_button:
if len(self.masks_dict) > 0:
# Save clicked coordinates for next plot
self.key_coords = np.rint(event.ydata).astype(int), np.rint(event.xdata).astype(int)
# Add or remove voxel from mask:
self.spaxel_selection()
# Save the new mask: # TODO just update the one we need
hdul = fits.HDUList([fits.PrimaryHDU()])
for mask_name, mask_attr in self.masks_dict.items():
hdul.append(fits.ImageHDU(name=mask_name, data=mask_attr[0].astype(int), ver=1, header=mask_attr[1]))
hdul.writeto(self.mask_file, overwrite=True, output_verify='fix')
# Replot the figure
self.im.remove()
self.ax[1].clear()
self.data_plots()
self.reset_zoom()
self.fig.canvas.draw()
return
def mask_selection(self, mask_label):
# Assign the mask
self.mask_ext = mask_label
# Zoom storage
self.save_zoom()
# Replot the figure
self.im.remove()
self.ax[1].clear()
self.data_plots()
self.reset_zoom()
self.fig.canvas.draw()
return
def spaxel_selection(self):
for mask, mask_data in self.masks_dict.items():
mask_matrix = mask_data[0]
if mask == self.mask_ext:
mask_matrix[self.key_coords[0], self.key_coords[1]] = not mask_matrix[self.key_coords[0], self.key_coords[1]]
else:
mask_matrix[self.key_coords[0], self.key_coords[1]] = False
self.masks_dict[mask] = mask_data
return
def on_enter_axes(self, event):
self.in_ax = event.inaxes
return
def save_zoom(self):
self.axlim_dict['image_xlim'] = self.ax[0].get_xlim()
self.axlim_dict['image_ylim'] = self.ax[0].get_ylim()
self.axlim_dict['spec_xlim'] = self.ax[1].get_xlim()
self.axlim_dict['spec_ylim'] = self.ax[1].get_ylim()
return
def reset_zoom(self):
if self.restore_zoom:
self.ax[0].set_xlim(self.axlim_dict['image_xlim'])
self.ax[0].set_ylim(self.axlim_dict['image_ylim'])
self.ax[1].set_xlim(self.axlim_dict['spec_xlim'])
if self.maintain_y_zoom:
self.ax[1].set_ylim(self.axlim_dict['spec_ylim'])
else:
self.ax[1].relim()
self.ax[1].autoscale_view()
return
def click_home(self):
self.restore_zoom = False
self.ax[1].relim()
self.ax[1].autoscale_view()
return
def click_zoom(self, event):
if self.in_ax == self.ax[1]:
if self.toolbar.mode == 'zoom rect' or self.toolbar.mode == 'pan/zoom':
self.restore_zoom = True
return
def get_spaxel_spec(self):
if self.key_coords is not None:
idx_j, idx_i = self.key_coords
spec = self._cube.get_spectrum(idx_j, idx_i)
# Check if lines have been measured
if self.hdul_linelog is not None:
ext_name = f'{idx_j}-{idx_i}{self.ext_log}'
# Better sorry than permission. Faster?
try:
log = pd.DataFrame.from_records(data=self.hdul_linelog[ext_name].data, index='index')
spec.load_frame(log)
except KeyError:
_logger.info(f'Extension {ext_name} not found in the input file')
return spec
else:
return None
class SpectrumCheck(Plotter, BandsInspection):
def __init__(self, spectrum):
# Instantiate the dependencies
Plotter.__init__(self)
BandsInspection.__init__(self)
# Lime spectrum object with the scientific data
self._spec = spectrum
# Variables for the matplotlib figures
self._fig, self._ax = None, None
return
class CubeCheck(Plotter, CubeInspection):
def __init__(self, cube):
# Instantiate the dependencies
Plotter.__init__(self)
CubeInspection.__init__(self)
# Lime cube object with the scientific data
self._cube = cube
# Variables for the matplotlib figures
self._fig, self._ax = None, None
return
class SampleCheck(Plotter, RedshiftInspection):
def __init__(self, sample):
# Instantiate the dependencies
Plotter.__init__(self)
RedshiftInspection.__init__(self)
# Lime spectrum object with the scientific data
self._sample = sample
# Variables for the matplotlib figures
self._fig, self._ax = None, None
return