{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Creating a new Source class\n\nExtending sncosmo with a custom type of Source.\n\nA ``Source`` is something that specifies a spectral time series as\na function of an arbitrary number of parameters. For example, the SALT2\nmodel has three parameters (``x0``, ``x1`` and ``c``) that determine a\nunique spectrum as a function of phase. The ``SALT2Source`` class implements\nthe behavior of the model: how the spectral time series depends on those\nparameters.\n\nIf you have a spectral time series model that follows the behavior of one of\nthe existing classes, such as ``TimeSeriesSource``, great! There's no need to\nwrite a custom class. However, suppose you want to implement a model that\nhas some new parameterization. In this case, you need a new class that\nimplements the behavior.\n\nIn this example, we implement a new type of source model. Our model is a linear\ncombination of two spectral time series, with an overall ``amplitude`` and a\nparameter ``w`` that sets the relative weight of the two components: the flux\nis ``amplitude * ((1 - w) * flux1 + w * flux2)``, so ``w = 0`` gives the first\nmodel alone and ``w = 1`` gives the second.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nfrom scipy.interpolate import RectBivariateSpline\nimport sncosmo\n\n\nclass ComboSource(sncosmo.Source):\n    # Flux model: amplitude * ((1 - w) * flux1 + w * flux2)\n\n    _param_names = ['amplitude', 'w']\n    param_names_latex = ['A', 'w']   # used in plotting display\n\n    def __init__(self, phase, wave, flux1, flux2, name=None, version=None):\n        self.name = name\n        self.version = version\n        # The base class uses these grids to report the model's phase and\n        # wavelength range (minphase, maxphase, minwave, maxwave).\n        self._phase = phase\n        self._wave = wave\n\n        # ensure that fluxes are on the same scale (equal peak flux)\n        flux2 = flux1.max() / flux2.max() * flux2\n\n        # bicubic interpolators over the shared (phase, wave) grid\n        self._model_flux1 = RectBivariateSpline(phase, wave, flux1, kx=3, ky=3)\n        self._model_flux2 = RectBivariateSpline(phase, wave, flux2, kx=3, ky=3)\n        self._parameters = np.array([1., 0.5])  # initial [amplitude, w]\n\n    def _flux(self, phase, wave):\n        # returns a 2-d array with shape (len(phase), len(wave))\n        amplitude, w = self._parameters\n        return amplitude * ((1.0 - w) * self._model_flux1(phase, wave) +\n                            w * self._model_flux2(phase, wave))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "... and that's all we need to define: a couple of class attributes\n(``_param_names`` and ``param_names_latex``), an ``__init__`` method,\nand a ``_flux`` method. The ``_flux`` method is guaranteed to be passed\nnumpy arrays for phase and wavelength, and should return a 2-d array of\nflux values with shape ``(len(phase), len(wave))``.\n\nWe can now initialize an instance of this source from two spectral time\nseries:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Just as an example, we'll use some undocumented functionality in\n# sncosmo to download the Nugent Ia and IIp templates. Don't rely on the\n# `DATADIR` object or these paths in your own code, though, as they are\n# subject to change between versions of sncosmo!\nfrom sncosmo.builtins import DATADIR\nphase1, wave1, flux1 = sncosmo.read_griddata_ascii(\n    DATADIR.abspath('models/nugent/sn1a_flux.v1.2.dat'))\nphase2, wave2, flux2 = sncosmo.read_griddata_ascii(\n    DATADIR.abspath('models/nugent/sn2p_flux.v1.2.dat'))\n\n# The __init__ method defined above expects both fluxes on the same\n# (phase, wave) grid, so interpolate the second template onto the grid\n# of the first:\nflux2_interp = RectBivariateSpline(phase2, wave2, flux2)(phase1, wave1)\n\nsource = ComboSource(phase1, wave1, flux1, flux2_interp, name='sn1a+sn2p')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can get a summary of the Source we created:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(source)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot the spectrum at phase 10 for several values of ``w``:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
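        "from matplotlib import pyplot as plt\n\n# evaluate the flux at phase 10 on a wavelength grid for several weights\nwave = np.linspace(2000.0, 10000.0, 500)\nfor w in (0.0, 0.2, 0.4, 0.6, 0.8, 1.0):\n    source.set(w=w)\n    plt.plot(wave, source.flux(10., wave), label='w={:3.1f}'.format(w))\n\nplt.xlabel('wavelength (Angstroms)')\nplt.ylabel('flux')\nplt.legend()\nplt.show()"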
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The ``w=0`` spectrum is that of the Ia model and the ``w=1`` spectrum is\nthat of the IIp model; intermediate values of ``w`` give weighted\ncombinations of the two.\n\nWe can even fit the model to some data!\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
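        "model = sncosmo.Model(source=source)\ndata = sncosmo.load_example_data()\n\n# fit redshift, time of peak, amplitude and weight; keep w within [0, 1]\nresult, fitted_model = sncosmo.fit_lc(data, model,\n                                      ['z', 't0', 'amplitude', 'w'],\n                                      bounds={'z': (0.2, 1.0),\n                                              'w': (0.0, 1.0)})\nprint('fitted w = {:.3f} +/- {:.3f}'.format(fitted_model.get('w'),\n                                            result.errors['w']))\n\nsncosmo.plot_lc(data, model=fitted_model, errors=result.errors)"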
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The fitted value of ``w`` is closer to 0 than to 1, indicating that\nthe light curve looks more like the Ia template than the IIp template.\nThis is what we expect, since the example data here were\ngenerated from a Ia template (although not the Nugent template!).\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}