Source code for gdt.core.plot.model

# CONTAINS TECHNICAL DATA/COMPUTER SOFTWARE DELIVERED TO THE U.S. GOVERNMENT WITH UNLIMITED RIGHTS
#
# Contract No.: CA 80MSFC17M0022
# Contractor Name: Universities Space Research Association
# Contractor Address: 7178 Columbia Gateway Drive, Columbia, MD 21046
#
# Copyright 2017-2022 by Universities Space Research Association (USRA). All rights reserved.
#
# Developed by: William Cleveland and Adam Goldstein
#               Universities Space Research Association
#               Science and Technology Institute
#               https://sti.usra.edu
#
# Developed by: Daniel Kocevski
#               National Aeronautics and Space Administration (NASA)
#               Marshall Space Flight Center
#               Astrophysics Branch (ST-12)
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing permissions and limitations under the
# License.
#
import numpy as np
from .plot import GdtPlot, Histo, ModelData, ModelSamples
from .plot import PlotElementCollection as Collection
from .lib import *
from .defaults import *
import warnings
import matplotlib.pyplot as plt

__all__ = ['ModelFit']


[docs]class ModelFit(GdtPlot): """Class for plotting spectral fits. Parameters: fitter (:class:`~gdt.spectra.fitting.SpectralFitter`, optional): The spectral fitter view (str, optional): The plot view, one of 'counts', 'photon', 'energy' or 'nufnu'. Default is 'counts' resid (bool, optional): If True, plots the residuals in counts view. Default is True. **kwargs: Options to pass to :class:`~gdt.core.plot.plot.GdtPlot` """ colors = '#7F3C8D,#11A579,#3969AC,#F2B701,#E73F74,#80BA5A,#E68310,#008695,#CF1C90,#f97b72,#4b4b8f,#A5AA99'.split(',') """(list): A list of default plotting colors to cycle through""" _min_y = 1e-10 def __init__(self, fitter=None, canvas=None, view='counts', resid=True, interactive=True): warnings.filterwarnings("ignore", category=np.exceptions.VisibleDeprecationWarning) self._figure, axes = plt.subplots(2, 1, sharex=True, sharey=False, figsize=(5.7, 6.7), dpi=100, gridspec_kw={'height_ratios': [3,1]}) plt.subplots_adjust(hspace=0) self._ax = axes[0] self._resid_ax = axes[1] self._view = view self._fitter = None self._count_models = Collection() self._count_data = Collection() self._resids = Collection() self._spectrum_model = Collection() # plot data and/or background if set on init if fitter is not None: self.set_fit(fitter, resid=resid) if interactive: plt.ion() @property def count_data(self): """(:class:`~gdt.core.plot.plot.PlotElementCollection` of :class:`~gdt.core.plot.plot.ModelData`): The count data plot elements""" return self._count_data @property def count_models(self): """(:class:`~gdt.core.plot.plot.PlotElementCollection` of :class:`~gdt.core.plot.plot.Histo`): The count model plot elements""" return self._count_models @property def spectrum_model(self): """(:class:`~gdt.core.plot.plot.PlotElementCollection` of :class:`~gdt.core.plot.plot.ModelSamples`): The model spectrum sample elements""" return self._spectrum_model @property def residuals(self): """(:class:`~gdt.core.plot.plot.PlotElementCollection` of :class:`~gdt.core.plot.plot.ModelData`): The fit residual plot elements""" return self._resids @property def view(self): """(str): The current plot view""" return self._view
[docs] def count_spectrum(self): """Plot the count spectrum fit. """ self._view = 'counts' self._ax.clear() model_counts = self._fitter.model_count_spectrum() energy, chanwidths, data_counts, data_counts_err, ulmasks = \ self._fitter.data_count_spectrum() for i in range(self._fitter.num_sets): det = self._fitter.detectors[i] self._count_models.include(Histo(model_counts[i], self._ax, edges_to_zero=False, color=self.colors[i], alpha=1.0, label=det), name=det) self._count_data.include(ModelData(energy[i], data_counts[i], chanwidths[i], data_counts_err[i], self._ax, ulmask=ulmasks[i], color=self.colors[i], alpha=0.7, linewidth=0.9), name=det) self._ax.set_ylabel(r'Rate [count s$^{-1}$ keV$^{-1}$]') self._set_view() self._ax.legend()
[docs] def energy_spectrum(self, **kwargs): """Plot the energy spectrum model. Args: num_samples (int, optional): The number of sample spectra. Default is 100. plot_components (bool, optional): Set to False to only plot the overall model, not each component. Default is False. """ self._view = 'energy' self._plot_spectral_model(**kwargs) self._ax.set_ylabel(r'Energy Flux [ph cm$^{-2}$ s$^{-1}$]', fontsize=PLOTFONTSIZE)
[docs] def hide_residuals(self): """Hide the fit residuals. """ try: self._figure.delaxes(self._resid_ax) self._ax.xaxis.set_tick_params(which='both', labelbottom=True) self._ax.set_xlabel('Energy (keV)', fontsize=PLOTFONTSIZE) except: print('Residuals already hidden')
[docs] def nufnu_spectrum(self, **kwargs): """Plot the nuFnu spectrum model. Args: num_samples (int, optional): The number of sample spectra. Default is 100. plot_components (bool, optional): Set to False to only plot the overall model, not each component. Default is False. """ self._view = 'nufnu' self._plot_spectral_model(**kwargs) self._ax.set_ylabel(r'$\nu F_\nu$ [keV ph cm$^{-2}$ s$^{-1}$]', fontsize=PLOTFONTSIZE)
[docs] def photon_spectrum(self, **kwargs): """Plot the photon spectrum model. Args: num_samples (int, optional): The number of sample spectra. Default is 10. plot_components (bool, optional): Set to False to only plot the overall model, not each component. Default is False. """ self._view = 'photon' self._plot_spectral_model(**kwargs) self._ax.set_ylabel(r'Photon Flux [ph cm$^{-2}$ s$^{-1}$ keV$^{-1}$]', fontsize=PLOTFONTSIZE)
[docs] def set_fit(self, fitter, resid=False): """Set the fitter. If a fitter already exists, this triggers a replot of the fit. Args: fitter (:class:`~gdt.spectra.fitting.SpectralFitter`): The spectral fitter for which a fit has been performed resid (bool, optional): If True, plot the fit residuals """ self._fitter = fitter if self._view == 'counts': self.count_spectrum() if resid: self.show_residuals() else: self.hide_residuals() elif self._view == 'photon': self.photon_spectrum() elif self._view == 'energy': self.energy_spectrum() elif self._view == 'nufnu': self.nufnu_spectrum() else: pass
[docs] def show_residuals(self, sigma=True): """Show the fit residuals. Args: sigma (bool, optional): If True, plot the residuals in units of model sigma, otherwise in units of counts. Default is True. """ # if we don't already have residuals axis if len(self._figure.axes) == 1: self._figure.add_axes(self._resid_ax) # get the residuals energy, chanwidths, resid, resid_err = self._fitter.residuals(sigma=sigma) # plot for each detector/dataset ymin, ymax = ([], []) for i in range(self._fitter.num_sets): det = self._fitter.detectors[i] self._resids.include(ModelData(energy[i], resid[i], chanwidths[i], resid_err[i], self._resid_ax, color=self.colors[i], alpha=0.7, linewidth=0.9), name=det) ymin.append((resid[i]-resid_err[i]).min()) ymax.append((resid[i]+resid_err[i]).max()) # the zero line self._resid_ax.axhline(0.0, color='black') self._resid_ax.set_xlabel('Energy [kev]', fontsize=PLOTFONTSIZE) if sigma: self._resid_ax.set_ylabel('Residuals [sigma]', fontsize=PLOTFONTSIZE) else: self._resid_ax.set_ylabel('Residuals [counts]', fontsize=PLOTFONTSIZE) # we have to set the y-axis range manually, because the y-axis # autoscale is broken (known issue) in matplotlib for this situation ymin = np.min(ymin) ymax = np.max(ymax) self._resid_ax.set_ylim((1.0-np.sign(ymin)*0.1)*ymin, (1.0+np.sign(ymax)*0.1)*ymax)
def _plot_spectral_model(self, num_samples=100, plot_components=True): """Plot the spectral model by sampling from the Gaussian approximation to the parameters' posterior. Args: num_samples (int, optional): The number of sample spectra. Default is 100. """ # clean plot and hide residuals if any warnings.filterwarnings("ignore", category=UserWarning) self._ax.clear() self.hide_residuals() num_comp = self._fitter.num_components comps = self._fitter.function_components name = self._fitter.function_name self._spectrum_model = Collection() # if the number of model components is > 1, plot each one if (num_comp > 1) and (plot_components): energies, samples = self._fitter.sample_spectrum(which=self._view, num_samples=num_samples, components=True) for i in range(num_comp): model = ModelSamples(energies, samples[:,i,:], self._ax, label=comps[i], color=self.colors[i+1], alpha=0.1, lw=0.3) self._spectrum_model.include(model) samples = samples.sum(axis=1) else: # or just plot the function energies, samples = self._fitter.sample_spectrum(which=self._view, num_samples=num_samples) y_max = samples.max(axis=(1,0)) self._spectrum_model.include(ModelSamples(energies, samples, self._ax, label=name, color=self.colors[0], alpha=0.1, lw=0.3)) self._set_view() # fix the alphas for the legend legend = self._ax.legend() for lh in legend.legend_handles: lh.set_alpha(1) lh.set_linewidth(1.0) if self._ax.get_ylim()[0] < self._min_y: self._ax.set_ylim(self._min_y, 10.0*y_max) def _set_view(self): """Set the view properties """ self._ax.set_xlim(self._fitter.energy_range) self._ax.yaxis.set_tick_params(labelsize=PLOTFONTSIZE) self._ax.set_xscale('log') self._ax.set_yscale('log') self._ax.set_xlabel('Energy [kev]', fontsize=PLOTFONTSIZE)