{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b58d6954",
   "metadata": {},
   "source": [
    "# NIRCam PSF Photometry - Basic Example on stage-2 images"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e23e9b88",
   "metadata": {},
   "source": [
    "**Author**: Matteo Correnti, STScI Scientist II\n",
    "<br>\n",
    "**Last Updated**: June 12, 2022"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7df4f696",
   "metadata": {},
   "source": [
    "## Table of contents\n",
    "1. [Introduction](#intro)<br>\n",
    "2. [Setup](#setup)<br>\n",
    "    2.1 [Python imports](#py_imports)<br>\n",
    "    2.2 [Plotting functions imports](#matpl_imports)<br>\n",
    "    2.3 [PSF FWHM dictionary](#psf_fwhm)<br>\n",
    "3. [Import image to analyze](#data)<br>\n",
    "    3.1 [Display image](#display_data)<br>\n",
    "    3.2 [Convert image units and apply pixel area map](#convert_data)<br>\n",
    "4. [Create a synthetic PSF model (with WebbPSF)](#webbpsf_intro)<br>\n",
    "    4.1 [Create the single PSF](#single_webbpsf)<br>\n",
    "    4.2 [Display the single PSF](#display_single_webbpsf)<br>\n",
    "    4.3 [Create the grid of PSFs](#grid_webbpsf)<br>\n",
    "    4.4 [Display the grid of PSFs](#display_grid_webbpsf)<br>\n",
    "5. [Create the PSF model building an effective PSF](#epsf_intro)<br>\n",
    "    5.1 [Calculate the background](#bkg)<br>\n",
    "    5.2 [Find sources in the image](#find)<br>\n",
    "    5.3 [Select sources](#select)<br>\n",
    "    5.4 [Create catalog of selected sources](#create_cat)<br>\n",
    "    5.5 [Build the effective PSF](#build_epsf)<br>\n",
    "    5.6 [Display the effective PSF](#display_epsf)<br>\n",
    "6. [Perform PSF Photometry](#psf_phot)<br>\n",
    "    6.1 [PSF photometry output catalog](#psf_cat)<br>\n",
    "    6.2 [Display residual images](#residual)<br>\n",
    "7. [Exercise: perform PSF photometry on the other filter](#exercise)<br>\n",
    "8. [Bonus Part I: create your first NIRCam Color-Magnitude Diagram](#bonusI)<br>\n",
    "    8.1 [Load images and output catalogs](#load_data)<br>\n",
    "    8.2 [Cross-match PSF photometry catalogs](#cross_match)<br>\n",
    "    8.3 [Load input catalogs](#load_input)<br>\n",
    "    8.4 [Cross-match input catalogs](#cross_match_input)<br>\n",
    "    8.5 [Instrumental Color-Magnitude Diagram](#cmd)<br>\n",
    "9. [Bonus part II: create a grid of empirical PSFs](#bonusII)<br>\n",
    "    9.1 [Count stars in N x N grid](#count_stars)<br>\n",
    "    9.2 [Build effective PSF (single or grid)](#epsf_grid)<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f572688",
   "metadata": {},
   "source": [
    "1.<font color='white'>-</font>Introduction <a class=\"anchor\" id=\"intro\"></a>\n",
    "------------------"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95891849",
   "metadata": {},
   "source": [
    "**Data**: NIRCam simulated images of the Large Magellanic Cloud (LMC) Astrometric Calibration Field, obtained using [MIRAGE](https://jwst-docs.stsci.edu/jwst-other-tools/mirage-data-simulator) and run through the [JWST pipeline](https://jwst-pipeline.readthedocs.io/en/latest/). Simulations are obtained using a 4-point subpixel dither for two pairs of wide filters: F115W and F200W for the SW channel, and F277W and F444W for the LW channel. We simulated only one NIRCam SW detector (i.e., \"NRCB1\").\n",
    "\n",
    "For this example, we use one Level-2 image (.cal, calibrated but not rectified) for each of the two SW filters (i.e., F115W and F200W) and derive the photometry in each of them.\n",
    "\n",
    "PSF photometry can be obtained using:\n",
    "\n",
    "* a single PSF model obtained from WebbPSF\n",
    "* a grid of PSF models from WebbPSF\n",
    "* a single effective PSF (ePSF)\n",
    "* a grid of effective PSFs (bonus part II)\n",
    "\n",
    "The notebook shows:\n",
    "\n",
    "* how to obtain the PSF model from WebbPSF (or build an ePSF)\n",
    "* how to perform PSF photometry on the image\n",
    "* how to cross-match the catalogs of the different images (bonus part I)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cf3d18f",
   "metadata": {},
   "source": [
    "2.<font color='white'>-</font>Setup <a class=\"anchor\" id=\"setup\"></a>\n",
    "------------------"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd038c76",
   "metadata": {},
   "source": [
    "In this section we import all the necessary Python packages and define some plotting parameters."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b762602",
   "metadata": {},
   "source": [
    "### 2.1<font color='white'>-</font>Python imports<a class=\"anchor\" id=\"py_imports\"></a> ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c50eace",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "if not os.environ.get('WEBBPSF_PATH'):\n",
    "    os.environ['WEBBPSF_PATH'] = '/data/webbpsf-data'\n",
    "\n",
    "import sys\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import glob\n",
    "\n",
    "import urllib.request\n",
    "\n",
    "import tarfile\n",
    "\n",
    "from astropy.io import fits\n",
    "from astropy.visualization import simple_norm\n",
    "from astropy.nddata import NDData\n",
    "from astropy.modeling.fitting import LevMarLSQFitter\n",
    "from astropy.stats import SigmaClip, sigma_clipped_stats  # needed by the 2D background option in calc_bkg\n",
    "from astropy.table import Table, QTable\n",
    "from astropy.coordinates import SkyCoord, match_coordinates_sky\n",
    "from astropy import units as u\n",
    "\n",
    "from photutils.background import MMMBackground, MADStdBackgroundRMS, Background2D\n",
    "from photutils.detection import DAOStarFinder\n",
    "from photutils.psf import EPSFBuilder, GriddedPSFModel\n",
    "from photutils.psf import DAOGroup, extract_stars, IterativelySubtractedPSFPhotometry\n",
    "\n",
    "import jwst\n",
    "from jwst.datamodels import ImageModel\n",
    "\n",
    "import webbpsf\n",
    "from webbpsf.utils import to_griddedpsfmodel\n",
    "\n",
    "import pysynphot  # PYSYN_CDBS must be defined in the user's environment (see note below)\n",
    "\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35987e59",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-block alert-info\">\n",
    "\n",
    "**Note on pysynphot**: Data files for pysynphot are distributed separately by the Calibration Reference Data System. They are expected to follow a certain directory structure under the root directory, identified by the `PYSYN_CDBS` environment variable, which must be set before using this package. In the example below, the root directory is arbitrarily named /my/local/dir/trds/. \\\n",
    "export PYSYN_CDBS=/my/local/dir/trds/ \\\n",
    "See the documentation [here](https://pysynphot.readthedocs.io/en/latest/#installation-and-setup) for the configuration and download of the data files.\n",
    "    \n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "956910c8",
   "metadata": {},
   "source": [
    "### 2.2<font color='white'>-</font>Plotting function imports<a class=\"anchor\" id=\"matpl_imports\"></a> ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04df7084-75fc-4a64-9742-67d6f41582d3",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-block alert-warning\">\n",
    "    <h3><u><b>Warning</b></u></h3>\n",
    "\n",
    "If the plots in this notebook don't render properly, you may need to install LaTeX. Find more information on a system-wide installation [here](https://www.latex-project.org/get/) and on a Jupyter-specific nbextension [here](https://jupyter-contrib-nbextensions.readthedocs.io/en/latest/nbextensions/latex_envs/README.html).\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dee46d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from matplotlib import style, pyplot as plt\n",
    "import matplotlib.patches as patches\n",
    "import matplotlib.ticker as ticker\n",
    "\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "\n",
    "plt.rcParams['image.cmap'] = 'viridis'\n",
    "plt.rcParams['image.origin'] = 'lower'\n",
    "plt.rcParams['axes.titlesize'] = plt.rcParams['axes.labelsize'] = 30\n",
    "plt.rcParams['xtick.labelsize'] = plt.rcParams['ytick.labelsize'] = 20\n",
    "\n",
    "font1 = {'family': 'helvetica', 'color': 'black', 'weight': 'normal', 'size': '12'}\n",
    "font2 = {'family': 'helvetica', 'color': 'black', 'weight': 'normal', 'size': '20'}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64814f46",
   "metadata": {},
   "source": [
    "### 2.3<font color='white'>-</font>PSF FWHM dictionary<a class=\"anchor\" id=\"psf_fwhm\"></a> ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f5a21ce",
   "metadata": {},
   "source": [
    "This dictionary contains the FWHM (in pixels) of the NIRCam point spread function (PSF) for each filter, taken from the [NIRCam Point Spread Function](https://jwst-docs.stsci.edu/near-infrared-camera/nircam-predicted-performance/nircam-point-spread-functions) JDox page. The FWHM values are calculated from the analysis of the expected NIRCam PSFs simulated with [WebbPSF](https://www.stsci.edu/jwst/science-planning/proposal-planning-toolbox/psf-simulation-tool).\n",
    "\n",
    "**Note**: this dictionary will be updated once the FWHM values for each detector are available after commissioning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "176a76b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "filters = ['F070W', 'F090W', 'F115W', 'F140M', 'F150W2', 'F150W', 'F162M', 'F164N', 'F182M',\n",
    "           'F187N', 'F200W', 'F210M', 'F212N', 'F250M', 'F277W', 'F300M', 'F322W2', 'F323N',\n",
    "           'F335M', 'F356W', 'F360M', 'F405N', 'F410M', 'F430M', 'F444W', 'F460M', 'F466N', 'F470N', 'F480M']\n",
    "\n",
    "psf_fwhm = [0.987, 1.103, 1.298, 1.553, 1.628, 1.770, 1.801, 1.494, 1.990, 2.060, 2.141, 2.304, 2.341, 1.340,\n",
    "            1.444, 1.585, 1.547, 1.711, 1.760, 1.830, 1.901, 2.165, 2.179, 2.300, 2.302, 2.459, 2.507, 2.535, 2.574]\n",
    "\n",
    "dict_utils = {filters[i]: {'psf fwhm': psf_fwhm[i]} for i in range(len(filters))}"
   ]
  },
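  {
   "cell_type": "markdown",
   "id": "5d1e0a01",
   "metadata": {},
   "source": [
    "As an illustration of how this dictionary can be used, the sketch below looks up the FWHM of two filters. It then converts each value into the standard deviation of a Gaussian with the same width, using FWHM = 2 sqrt(2 ln 2) sigma (about 2.355 sigma). The helper `fwhm_to_sigma` is hypothetical, and treating the NIRCam PSF as a Gaussian is only an approximation. The two FWHM values are copied from the dictionary above.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# two entries copied from the dictionary above (FWHM in pixels)\n",
    "psf_fwhm_px = {'F115W': 1.298, 'F200W': 2.141}\n",
    "\n",
    "\n",
    "def fwhm_to_sigma(fwhm):\n",
    "    # hypothetical helper: sigma of a Gaussian with the given FWHM\n",
    "    return fwhm / (2.0 * math.sqrt(2.0 * math.log(2.0)))\n",
    "\n",
    "\n",
    "for name, fwhm in psf_fwhm_px.items():\n",
    "    print(name, 'FWHM =', fwhm, 'px -> sigma =', round(fwhm_to_sigma(fwhm), 3), 'px')\n",
    "```\n",
    "\n",
    "DAOStarFinder expects the FWHM itself rather than sigma, so the dictionary values can be passed directly to its `fwhm` argument (see section 5.2)."
   ]
  },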
  {
   "cell_type": "markdown",
   "id": "3c7f81bd",
   "metadata": {},
   "source": [
    "3.<font color='white'>-</font>Import images to analyze<a class=\"anchor\" id=\"data\"></a>\n",
    "------------------"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03becf42",
   "metadata": {},
   "source": [
    "We load all the images and create a dictionary that contains all of them, organized by detector and filter. This is useful to check which detectors and filters are available and to decide whether to perform the photometry on all of them or only on a subset (for example, only on the SW filters).\n",
    "\n",
    "We retrieve the NIRCam detector and filter from the image header. Note that for the LW channel, we rename the detector names derived from the header (**NRCALONG** and **NRCBLONG**) to **NRCA5** and **NRCB5**, respectively. This matches the detector naming used by WebbPSF, so the same variable can be used when creating a PSF with WebbPSF."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32374ae6",
   "metadata": {},
   "outputs": [],
   "source": [
    "dict_images = {'NRCA1': {}, 'NRCA2': {}, 'NRCA3': {}, 'NRCA4': {}, 'NRCA5': {},\n",
    "               'NRCB1': {}, 'NRCB2': {}, 'NRCB3': {}, 'NRCB4': {}, 'NRCB5': {}}\n",
    "\n",
    "ff_short = []\n",
    "det_short = []\n",
    "det_long = []\n",
    "ff_long = []\n",
    "\n",
    "# download and extract the Level-2 images only if they are not already in the working directory\n",
    "if not glob.glob('./*cal*fits'):\n",
    "\n",
    "    print(\"Downloading images\")\n",
    "\n",
    "    boxlink_images_lev2 = 'https://stsci.box.com/shared/static/dt0gm0lyvi1yfh942vad6a8gb5utd85m.gz'\n",
    "    boxfile_images_lev2 = './single_images_lev2.tar.gz'\n",
    "    urllib.request.urlretrieve(boxlink_images_lev2, boxfile_images_lev2)\n",
    "\n",
    "    with tarfile.open(boxfile_images_lev2, 'r') as tar:\n",
    "        tar.extractall()\n",
    "\n",
    "images_dir = './'\n",
    "images = sorted(glob.glob(os.path.join(images_dir, \"*cal.fits\")))\n",
    "\n",
    "for image in images:\n",
    "\n",
    "    # read filter and detector from the primary header\n",
    "    with fits.open(image) as im:\n",
    "        f = im[0].header['FILTER']\n",
    "        d = im[0].header['DETECTOR']\n",
    "\n",
    "    # rename the LW detectors following the WebbPSF naming convention\n",
    "    if d == 'NRCBLONG':\n",
    "        d = 'NRCB5'\n",
    "    elif d == 'NRCALONG':\n",
    "        d = 'NRCA5'\n",
    "\n",
    "    # characters 2-3 of the filter name give the wavelength in units of 0.1 micron (e.g., F277W -> 27)\n",
    "    wv = float(f[1:3])\n",
    "\n",
    "    if wv > 24:\n",
    "        ff_long.append(f)\n",
    "        det_long.append(d)\n",
    "    else:\n",
    "        ff_short.append(f)\n",
    "        det_short.append(d)\n",
    "\n",
    "    # each detector gets its own filter dictionary (a shared one would mix images of different detectors)\n",
    "    dict_images[d].setdefault(f, {'images': []})\n",
    "    dict_images[d][f]['images'].append(image)\n",
    "\n",
    "detlist_short = sorted(set(det_short))\n",
    "detlist_long = sorted(set(det_long))\n",
    "filtlist_short = sorted(set(ff_short))\n",
    "filtlist_long = sorted(set(ff_long))\n",
    "\n",
    "print(\"Available Detectors for SW channel:\", detlist_short)\n",
    "print(\"Available Detectors for LW channel:\", detlist_long)\n",
    "print(\"Available SW Filters:\", filtlist_short)\n",
    "print(\"Available LW Filters:\", filtlist_long)"
   ]
  },
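  {
   "cell_type": "markdown",
   "id": "5d1e0a02",
   "metadata": {},
   "source": [
    "The channel assignment in the cell above depends only on the filter name. Its second and third characters give the wavelength in units of 0.1 micron (F277W -> 27), and filters above 2.4 micron belong to the LW channel. The sketch below isolates that rule, together with the LW detector renaming, in a hypothetical helper `classify_image`. The helper takes plain strings instead of FITS header values, so the logic can be checked without any image.\n",
    "\n",
    "```python\n",
    "# WebbPSF-style names for the two LW detectors\n",
    "LW_DETECTOR_NAMES = {'NRCALONG': 'NRCA5', 'NRCBLONG': 'NRCB5'}\n",
    "\n",
    "\n",
    "def classify_image(filter_name, detector):\n",
    "    # hypothetical helper mirroring the loop above: returns (channel, detector name)\n",
    "    det_name = LW_DETECTOR_NAMES.get(detector, detector)\n",
    "    wavelength = float(filter_name[1:3])  # e.g. 'F277W' -> 27.0 (units of 0.1 micron)\n",
    "    channel = 'LW' if wavelength > 24 else 'SW'\n",
    "    return channel, det_name\n",
    "\n",
    "\n",
    "for filter_name, detector in [('F115W', 'NRCB1'), ('F200W', 'NRCB1'), ('F277W', 'NRCBLONG'), ('F444W', 'NRCBLONG')]:\n",
    "    print(filter_name, detector, '->', classify_image(filter_name, detector))\n",
    "```\n",
    "\n",
    "Under this rule, F250M (25, i.e., 2.5 micron) is also assigned to the LW channel."
   ]
  },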
  {
   "cell_type": "markdown",
   "id": "c27cf3db",
   "metadata": {},
   "source": [
    "**Note**: in this particular example, we analyze each image separately to provide a general overview of the different steps necessary to perform PSF photometry and to highlight the different functions adopted in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6c85339",
   "metadata": {},
   "outputs": [],
   "source": [
    "det = 'NRCB1'\n",
    "filt = 'F115W'\n",
    "\n",
    "im = fits.open(dict_images[det][filt]['images'][0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9d2a770",
   "metadata": {},
   "source": [
    "### 3.1<font color='white'>-</font>Display the image<a class=\"anchor\" id=\"display_data\"></a> ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f012553",
   "metadata": {},
   "source": [
    "We display the image to check that it is free of artifacts and can be used in the analysis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "203b30fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 12))\n",
    "   \n",
    "data_sb = im[1].data\n",
    "\n",
    "ax = plt.subplot(1, 1, 1)\n",
    "\n",
    "plt.xlabel(\"X [px]\", fontdict=font2)\n",
    "plt.ylabel(\"Y [px]\", fontdict=font2)\n",
    "plt.title(filt, fontdict=font2)\n",
    "norm = simple_norm(data_sb, 'sqrt', percent=99.)\n",
    "\n",
    "ax.imshow(data_sb, norm=norm, cmap='Greys')\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58db4ab7",
   "metadata": {},
   "source": [
    "### 3.2<font color='white'>-</font>Convert image units and apply pixel area map<a class=\"anchor\" id=\"convert_data\"></a> ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10dac812",
   "metadata": {},
   "source": [
    "Level-2 and Level-3 images from the pipeline are in units of MJy/sr (i.e., surface brightness). The actual unit of the image can be checked in the header keyword **BUNIT**. The scalar conversion constant is stored in the header keyword **PHOTMJSR**, which gives the conversion from DN/s to MJy/sr. For our analysis, we convert the image back to DN/s by dividing it by **PHOTMJSR**.\n",
    "\n",
    "For images that have not been transformed into a distortion-free frame (i.e., not drizzled), a correction must be applied to account for the different on-sky pixel size across the field of view. This correction uses a pixel area map (PAM), an image where each pixel value describes that pixel's area on the sky relative to the native plate scale. In stage 2 of the JWST pipeline, the PAM is copied into an image extension called **AREA** in the science data product."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d952177",
   "metadata": {},
   "outputs": [],
   "source": [
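    "imh = im[1].header\n",
    "\n",
    "# convert from MJy/sr to DN/s\n",
    "data = data_sb / imh['PHOTMJSR']\n",
    "print('Conversion factor from {units} to DN/s for filter {f}:'.format(units=imh['BUNIT'], f=filt), imh['PHOTMJSR'])\n",
    "\n",
    "# apply the pixel area map, read from the AREA extension by name rather than by index\n",
    "area = im['AREA'].data\n",
    "data = data * area"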
   ]
  },
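  {
   "cell_type": "markdown",
   "id": "5d1e0a03",
   "metadata": {},
   "source": [
    "Both steps are element-wise array operations. The sketch below reproduces them on tiny synthetic arrays. The PHOTMJSR value, the image, and the pixel area map are made-up numbers that only show the order of operations (divide by PHOTMJSR, then multiply by the PAM). They are not real calibration values.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "photmjsr = 2.0  # made-up conversion factor, in MJy/sr per DN/s\n",
    "data_sb = np.array([[4.0, 8.0],\n",
    "                    [2.0, 6.0]])  # synthetic image in MJy/sr\n",
    "area = np.array([[1.00, 1.05],\n",
    "                 [0.95, 1.00]])  # synthetic pixel area map (relative pixel areas)\n",
    "\n",
    "# MJy/sr -> DN/s\n",
    "data_dns = data_sb / photmjsr\n",
    "# correct each pixel for its on-sky area relative to the native plate scale\n",
    "data_corr = data_dns * area\n",
    "\n",
    "print(data_corr)\n",
    "```"
   ]
  },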
  {
   "cell_type": "markdown",
   "id": "77f7e578",
   "metadata": {},
   "source": [
    "4.<font color='white'>-</font>Create synthetic PSF (with WebbPSF) <a class=\"anchor\" id=\"webbpsf_intro\"></a>\n",
    "------------------"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f0cb8a6",
   "metadata": {},
   "source": [
    "WebbPSF is a Python package that computes simulated PSFs for NASA’s JWST and the Nancy Grace Roman Space Telescope (formerly WFIRST). WebbPSF transforms models of telescope and instrument optical state into PSFs, taking into account detector pixel scales, rotations, filter profiles, and point source spectra. It is not a full optical model of JWST, but rather a tool for transforming optical path difference (OPD) maps, created with some other tool, into the resulting PSFs as observed with JWST’s or Roman’s instruments. For the full WebbPSF documentation, see [here](https://webbpsf.readthedocs.io/en/latest/), and for its capabilities and limitations, see [here](https://webbpsf.readthedocs.io/en/latest/intro.html).\n",
    "\n",
    "The function below creates a single PSF or a grid of PSFs (the PSF can also be saved as a FITS file). First, we specify the instrument (NIRCam), detector, and filter. Then, to create a single PSF or a grid of PSFs, we use the WebbPSF method *psf_grid*. This method returns a photutils GriddedPSFModel object, or a list of them if more than one configuration is requested (e.g., all detectors). A tutorial notebook on the *psf_grid* method can be found [here](https://github.com/spacetelescope/webbpsf/blob/stable/notebooks/Gridded_PSF_Library.ipynb).\n",
    "\n",
    "**Important Parameters**:\n",
    "\n",
    "* `num`: the total number of fiducial PSFs to be created and saved in the files. This\n",
    "    number must be a square number (4, 9, 16, etc.)\n",
    "\n",
    "* `oversample`: the oversampling factor adopted in the PSF creation (fixed to 4 in the function below).\n",
    "\n",
    "* `fov`: the size of the PSF model in detector pixels (passed to *psf_grid* as `fov_pixels`). The appropriate size depends on the shape of the PSF and on how much flux is contained in its wings (i.e., a small field of view excludes more flux from the PSF wings). However, a larger field of view also increases the computational time, so we need to find a reasonable compromise.\n",
    "\n",
    "* `source`: the source spectrum we want to adopt. Source spectra are defined using the function `webbpsf.specFromSpectralType`, where we need to define the spectral type and the model library (e.g., `webbpsf.specFromSpectralType('G5V', catalog='phoenix')`). See also the note below on the default spectrum, which depends on whether pysynphot is installed.\n",
    "\n",
    "* `all_detectors`: if True, create PSF models for all detectors of the instrument. Since we analyze only one detector, we set `all_detectors = False` (we do not need a PSF model for every NIRCam detector).\n",
    "\n",
    "* `use_detsampled_psf`: whether the returned grid of PSFs is detector-sampled (made by binning down the oversampled PSF) or oversampled by the factor defined by `oversample`. For our analysis, we want an oversampled PSF model, so we set `use_detsampled_psf = False`."
   ]
  },
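  {
   "cell_type": "markdown",
   "id": "5d1e0a04",
   "metadata": {},
   "source": [
    "`num` and `fov` translate into simple arithmetic. The `num` fiducial PSFs are laid out on a square grid (e.g., 16 PSFs form a 4 x 4 grid), which is why `num` must be a square number. The helper below also assumes that the oversampled model covers `fov` detector pixels, so that each oversampled PSF array is `fov * oversample` pixels on a side. This hypothetical helper (not part of WebbPSF) checks the first condition and predicts the array shape under that assumption. With `fov=11` and `oversample=4`, as used in this notebook, it predicts 44 x 44 arrays.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def check_psf_grid_setup(num, fov, oversample, detsampled=False):\n",
    "    # hypothetical helper: validate num and predict the (n_psfs, ny, nx) shape\n",
    "    side = math.isqrt(num)\n",
    "    if side * side != num:\n",
    "        raise ValueError('num must be a square number (1, 4, 9, 16, ...), got %d' % num)\n",
    "    # assumption: detector-sampled models are fov pixels wide, oversampled ones fov * oversample\n",
    "    npix = fov if detsampled else fov * oversample\n",
    "    return side, (num, npix, npix)\n",
    "\n",
    "\n",
    "print(check_psf_grid_setup(16, fov=11, oversample=4))  # 4 x 4 grid\n",
    "print(check_psf_grid_setup(1, fov=11, oversample=4))   # single PSF\n",
    "```"
   ]
  },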
  {
   "cell_type": "markdown",
   "id": "a49f149b",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-block alert-info\">\n",
    "    \n",
    "**Note on centering**: by default, the PSF will be centered at the exact center of the output array. This means that if the PSF is computed on an array with an odd number of pixels, the PSF will be centered exactly on the central pixel. If the PSF is computed on an array with even size, it will be centered on the “crosshairs” at the intersection of the central four pixels.\n",
    "\n",
    "**Note on normalization**: by default, PSFs are normalized to total intensity = 1.0 at the entrance pupil (i.e. at the JWST OTE primary). A PSF calculated for an infinite aperture would thus have integrated intensity =1.0. A PSF calculated on any smaller finite subarray will have some finite encircled energy less than one.\n",
    "\n",
    "**Note on source spectrum**: The default source spectrum is, if *pysynphot* is installed, a G2V star spectrum from Castelli & Kurucz (2004). Without *pysynphot*, the default is a simple flat spectrum such that the same number of photons are detected at each wavelength.\n",
    "\n",
    "</div>"
   ]
  },
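  {
   "cell_type": "markdown",
   "id": "5d1e0a05",
   "metadata": {},
   "source": [
    "Both notes can be illustrated with a toy model. The sketch below uses a circular 2D Gaussian (an assumption made for simplicity, not a WebbPSF model) normalized to unit integral over the infinite plane. The peak is placed at index (n - 1) / 2. For odd n this is a pixel center, and for even n it is the crosshairs between the four central pixels. Summing the model over a finite array always gives less than 1, just as a finite PSF array has an encircled energy below one.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def gaussian_psf(n, sigma):\n",
    "    # toy PSF sampled at pixel centers and centered on the array center (n - 1) / 2;\n",
    "    # the normalization gives a unit integral over the infinite plane\n",
    "    center = (n - 1) / 2.0\n",
    "    y, x = np.mgrid[0:n, 0:n]\n",
    "    r2 = (x - center) ** 2 + (y - center) ** 2\n",
    "    psf = np.exp(-r2 / (2.0 * sigma ** 2)) / (2.0 * np.pi * sigma ** 2)\n",
    "    return psf, center\n",
    "\n",
    "\n",
    "for n in (11, 10, 5):\n",
    "    psf, center = gaussian_psf(n, sigma=1.5)\n",
    "    print('n = %d: center at (%.1f, %.1f), sum over the array = %.4f' % (n, center, center, psf.sum()))\n",
    "```\n",
    "\n",
    "The smaller the array is relative to the PSF width, the more flux falls outside it. This is the same trade-off discussed above for the `fov` parameter."
   ]
  },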
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8f86e11",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_psf_model(det='NRCA1', filt='F070W', fov=101, source=None, create_grid=False, num=9, save_psf=False,\n",
    "                     detsampled=False):\n",
    "    # Create a single WebbPSF model (create_grid=False) or a grid of num models (num must be a square number).\n",
    "    # When save_psf=True, the output directory psfs_dir must already be defined (see section 4.1).\n",
    "\n",
    "    nrc = webbpsf.NIRCam()\n",
    "\n",
    "    nrc.detector = det\n",
    "    nrc.filter = filt\n",
    "\n",
    "    print(\"Using a {field}\".format(field=fov), \"px fov\")\n",
    "\n",
    "    if create_grid:\n",
    "        print(\"\")\n",
    "        print(\"Creating a grid of PSF for filter {filt} and detector {det}\".format(filt=filt, det=det))\n",
    "        print(\"\")\n",
    "    else:\n",
    "        print(\"\")\n",
    "        print(\"Creating a single PSF for filter {filt} and detector {det}\".format(filt=filt, det=det))\n",
    "        print(\"\")\n",
    "        num = 1\n",
    "\n",
    "    # keyword arguments shared by the single-PSF and the grid cases\n",
    "    psf_kwargs = dict(num_psfs=num, oversample=4, source=source, all_detectors=False, fov_pixels=fov,\n",
    "                      use_detsampled_psf=detsampled)\n",
    "\n",
    "    if save_psf:\n",
    "        outname = 'PSF_%s_samp4_fov%d_npsfs%d.fits' % (filt, fov, num)\n",
    "        psf = nrc.psf_grid(save=True, outfile=os.path.join(psfs_dir, outname), **psf_kwargs)\n",
    "    else:\n",
    "        psf = nrc.psf_grid(**psf_kwargs)\n",
    "\n",
    "    return psf"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cea02ed2",
   "metadata": {},
   "source": [
    "### 4.1<font color='white'>-</font>Create the single PSF<a class=\"anchor\" id=\"single_webbpsf\"></a> ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9bcf098",
   "metadata": {},
   "outputs": [],
   "source": [
    "psfs_dir = 'PSF_MODELS/'\n",
    "\n",
    "if not os.path.exists(psfs_dir):\n",
    "    os.makedirs(psfs_dir)\n",
    "\n",
    "psf_webbpsf_single = create_psf_model(det=det, filt=filt, fov=11, source=None, create_grid=False, save_psf=True, \n",
    "                                      detsampled=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c25d4b3",
   "metadata": {},
   "source": [
    "### 4.2<font color='white'>-</font>Display the single PSF<a class=\"anchor\" id=\"display_single_webbpsf\"></a> ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36b93483",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 12))\n",
    "\n",
    "ax = plt.subplot(1, 1, 1)\n",
    "\n",
    "norm_psf = simple_norm(psf_webbpsf_single.data[0], 'log', percent=99.)\n",
    "ax.set_title(filt, fontsize=40)\n",
    "ax.imshow(psf_webbpsf_single.data[0], norm=norm_psf)\n",
    "ax.set_xlabel('X [px]', fontsize=30)\n",
    "ax.set_ylabel('Y [px]', fontsize=30)\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbf3e59b",
   "metadata": {},
   "source": [
    "### 4.3<font color='white'>-</font>Create the grid of PSFs<a class=\"anchor\" id=\"grid_webbpsf\"></a> ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bedee6e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "psf_webbpsf_grid = create_psf_model(det=det, filt=filt, fov=11, source=None, create_grid=True, num=16, \n",
    "                                    save_psf=True, detsampled=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "910cfcf8",
   "metadata": {},
   "source": [
    "### 4.4<font color='white'>-</font>Display the grid of PSFs<a class=\"anchor\" id=\"display_grid_webbpsf\"></a> ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "849fcd0e",
   "metadata": {},
   "source": [
    "We show the grid of PSFs at their positions in detector coordinates, together with the difference of each model from the mean, to highlight how the PSF changes across the detector. We use the WebbPSF function *gridded_library.display_psf_grid*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdc9ab78",
   "metadata": {},
   "outputs": [],
   "source": [
    "webbpsf.gridded_library.display_psf_grid(psf_webbpsf_grid, zoom_in=False, figsize=(14, 14))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3fab536",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "ax = plt.subplot(1, 1, 1)\n",
    "plt.title('Difference central PSF - corner PSF')\n",
    "\n",
    "# single PSF (detector center) and first PSF of the grid (detector corner)\n",
    "model1 = psf_webbpsf_single.data[0]\n",
    "model2 = psf_webbpsf_grid.data[0]\n",
    "\n",
    "diff = model1 - model2\n",
    "# symmetric color scale around zero\n",
    "aa = np.max(np.abs(diff))\n",
    "divider = make_axes_locatable(ax)\n",
    "# use a new name so that the FITS HDUList stored in im is not overwritten\n",
    "im_diff = ax.imshow(diff, origin='lower', vmin=-aa, vmax=aa, cmap='RdBu')\n",
    "cax = divider.append_axes(\"right\", size=\"5%\", pad=0.05)\n",
    "plt.colorbar(im_diff, cax=cax)\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3a2eeb9",
   "metadata": {},
   "source": [
    "5.<font color='white'>-</font>Create PSF model building an effective PSF<a class=\"anchor\" id=\"epsf_intro\"></a>\n",
    "------------------"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b553a77",
   "metadata": {},
   "source": [
    "More information on the photutils effective PSF (ePSF) can be found [here](https://photutils.readthedocs.io/en/stable/epsf.html).\n",
    "\n",
    "The process of creating an effective PSF can be summarized as follows:\n",
    "\n",
    "* Find the stars in the image.\n",
    "* Select the stars we want to use for building the effective PSF. \n",
    "* Build the effective PSF."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3d727c8",
   "metadata": {},
   "source": [
    "### 5.1<font color='white'>-</font>Calculate the background<a class=\"anchor\" id=\"bkg\"></a> ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "733620fc",
   "metadata": {},
   "source": [
    "As background estimator, we adopt the function [MMMBackground](https://photutils.readthedocs.io/en/stable/api/photutils.background.MMMBackground.html#photutils.background.MMMBackground), which calculates the background of the whole image using the DAOPHOT MMM algorithm, i.e., a mode estimator of the form `(3 * median) - (2 * mean)`. The background rms is estimated with `MADStdBackgroundRMS`, which is based on the median absolute deviation.\n",
    "\n",
    "When dealing with a variable background and/or the need to mask regions with no data (for example, if we are analyzing an image with all four NIRCam SW detectors, i.e., containing the chip gaps), we can set `var_bkg = True`. This uses `Background2D` instead, which estimates the background in boxes of 100 x 100 pixels and masks the pixels with no data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "513831fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_bkg(var_bkg=False):\n",
    "    # Estimate the background of the global data array and return the\n",
    "    # background-subtracted image together with the background rms.\n",
    "\n",
    "    bkgrms = MADStdBackgroundRMS()\n",
    "    mmm_bkg = MMMBackground()\n",
    "\n",
    "    if var_bkg:\n",
    "        print('Using 2D Background')\n",
    "        sigma_clip = SigmaClip(sigma=3.)\n",
    "        # pixels equal to zero contain no data (e.g., chip gaps) and are masked\n",
    "        coverage_mask = (data == 0)\n",
    "\n",
    "        bkg = Background2D(data, (100, 100), filter_size=(3, 3), sigma_clip=sigma_clip, bkg_estimator=mmm_bkg,\n",
    "                           coverage_mask=coverage_mask, fill_value=0.0)\n",
    "\n",
    "        data_bkgsub = data - bkg.background\n",
    "\n",
    "        _, _, std = sigma_clipped_stats(data_bkgsub)\n",
    "\n",
    "    else:\n",
    "\n",
    "        # a single background value and rms for the whole image\n",
    "        std = bkgrms(data)\n",
    "        bkg = mmm_bkg(data)\n",
    "\n",
    "        data_bkgsub = data - bkg\n",
    "\n",
    "    return data_bkgsub, std"
   ]
  },
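  {
   "cell_type": "markdown",
   "id": "5d1e0a06",
   "metadata": {},
   "source": [
    "The mode estimator quoted above is easy to reproduce. The sketch below works on synthetic data: Gaussian sky noise plus a few bright pixels standing in for stars. It first applies an iterative 3-sigma clipping around the median, similar in spirit to the default clipping of the photutils background classes. It then evaluates `(3 * median) - (2 * mean)` on the surviving pixels. Finally, it computes the MAD-based standard deviation (1.4826 times the median absolute deviation) that MADStdBackgroundRMS relies on. This is a simplified illustration, not the photutils implementation, and the helper names and all numbers are made up.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(42)\n",
    "\n",
    "# synthetic sky: constant level plus Gaussian noise (made-up values)\n",
    "sky_level, sky_noise = 10.0, 2.0\n",
    "image = rng.normal(sky_level, sky_noise, size=(200, 200))\n",
    "# a few bright pixels standing in for stars, which skew the pixel distribution upward\n",
    "image[rng.integers(0, 200, 50), rng.integers(0, 200, 50)] += 500.0\n",
    "\n",
    "\n",
    "def sigma_clip_values(values, sigma=3.0, maxiters=10):\n",
    "    # iteratively reject values more than sigma standard deviations from the median\n",
    "    kept = values.ravel()\n",
    "    for _ in range(maxiters):\n",
    "        med, std = np.median(kept), np.std(kept)\n",
    "        mask = np.abs(kept - med) <= sigma * std\n",
    "        if mask.all():\n",
    "            break\n",
    "        kept = kept[mask]\n",
    "    return kept\n",
    "\n",
    "\n",
    "def mmm_mode(values):\n",
    "    # mode estimator of the form (3 * median) - (2 * mean)\n",
    "    return 3.0 * np.median(values) - 2.0 * np.mean(values)\n",
    "\n",
    "\n",
    "def mad_std(values):\n",
    "    # robust standard deviation: 1.4826 times the median absolute deviation\n",
    "    return 1.4826 * np.median(np.abs(values - np.median(values)))\n",
    "\n",
    "\n",
    "clipped = sigma_clip_values(image)\n",
    "bkg = mmm_mode(clipped)\n",
    "rms = mad_std(image)\n",
    "\n",
    "print('plain mean:', round(float(np.mean(image)), 3), '  clipped mode estimate:', round(float(bkg), 3))\n",
    "print('plain std:', round(float(np.std(image)), 3), '  MAD std:', round(float(rms), 3))\n",
    "```\n",
    "\n",
    "Without the clipping step, the bright pixels would pull the mean upward and bias the `(3 * median) - (2 * mean)` estimate low. The MAD-based rms, by contrast, stays close to the true noise even on the unclipped image."
   ]
  },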
  {
   "cell_type": "markdown",
   "id": "295a6048",
   "metadata": {},
   "source": [
    "### 5.2<font color='white'>-</font>Find sources in the image<a class=\"anchor\" id=\"find\"></a> ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "711b111d",
   "metadata": {},
   "source": [
    "To find sources in the image, we use the [DAOStarFinder](https://photutils.readthedocs.io/en/stable/api/photutils.detection.DAOStarFinder.html) function. \n",
    "\n",
    "[DAOStarFinder](https://photutils.readthedocs.io/en/stable/api/photutils.detection.DAOStarFinder.html) detects stars in an image using the DAOFIND ([Stetson 1987](https://ui.adsabs.harvard.edu/abs/1987PASP...99..191S/abstract)) algorithm. DAOFIND searches images for local density maxima that have a peak amplitude greater than `threshold` (approximately; threshold is applied to a convolved image) and have a size and shape similar to the defined 2D Gaussian kernel.\n",
    "\n",
    "**Important parameters**:\n",
    "\n",
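    "* `threshold`: the absolute image value above which to select sources. In the function below, the `threshold` argument is instead a significance level, which is multiplied by the background rms before being passed to DAOStarFinder.\n",
    "* `fwhm`: the full width at half maximum (FWHM) of the major axis of the Gaussian kernel, in pixels. We use the values from the PSF FWHM dictionary defined in section 2.3."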
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5667d21",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_stars(det='NRCA1', filt='F070W', threshold=3, var_bkg=False):\n",
    "    \n",
    "    print('Finding stars --- Detector: {d}, Filter: {f}'.format(f=filt, d=det))\n",
    "    \n",
    "    sigma_psf = dict_utils[filt]['psf fwhm']\n",
    "\n",
    "    print('FWHM for the filter {f}:'.format(f=filt), sigma_psf, \"px\")\n",
    "    \n",
    "    data_bkgsub, std = calc_bkg(var_bkg=var_bkg)\n",
    "    \n",
    "    daofind = DAOStarFinder(threshold=threshold * std, fwhm=sigma_psf)\n",
    "    found_stars = daofind(data_bkgsub)\n",
    "    \n",
    "    print('')\n",
    "    print('Number of sources found in the image:', len(found_stars))\n",
    "    print('-------------------------------------')\n",
    "    print('')\n",
    "    \n",
    "    return found_stars"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f618171",
   "metadata": {},
   "outputs": [],
   "source": [
    "tic = time.perf_counter()\n",
    "\n",
    "found_stars = find_stars(det=det, filt=filt, threshold=10, var_bkg=False)\n",
    "\n",
    "toc = time.perf_counter()\n",
    "\n",
    "print(\"Elapsed Time for finding stars:\", toc - tic)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d222b0e4",
   "metadata": {},
   "source": [
    "### 5.3<font color='white'>-</font>Select sources<a class=\"anchor\" id=\"select\"></a> ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c5a3a77",
   "metadata": {},
   "source": [
    "We can adopt different methods to select sources we want to use to build an effective PSF. Here, we select objects applying a brightness cut (we do not want to include objects that are too faint) and using the `roundness2` and `sharpness` parameters provided in the [DAOStarFinder](https://photutils.readthedocs.io/en/stable/api/photutils.detection.DAOStarFinder.html) output catalog.\n",
    "\n",
    "`roundness2` measures the ratio of the difference in the height of the best fitting Gaussian function in x minus the best fitting Gaussian function in y, divided by the average of the best fitting Gaussian functions in x and y.\n",
    "\n",
    "`sharpness` measures the ratio of the difference between the height of the central pixel and the mean of the surrounding non-bad pixels in the convolved image, to the height of the best fitting Gaussian function at that point."
   ]
  },
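  {
   "cell_type": "markdown",
   "id": "c41d7e22",
   "metadata": {},
   "source": [
    "To build some intuition for the sign convention of `roundness2`, the cell below evaluates the formula above on noiseless elliptical Gaussians. It is a simplified sketch: it takes the peak heights of the collapsed x and y profiles instead of the heights of the best-fitting 1D Gaussians described above, so its numbers will not match the [DAOStarFinder](https://photutils.readthedocs.io/en/stable/api/photutils.detection.DAOStarFinder.html) output exactly. `demo_roundness2` and `demo_gauss2d` are hypothetical helpers. A round source gives a value close to 0, a source extended along x gives a negative value, and a source extended along y gives a positive one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c41d7e23",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def demo_roundness2(cutout):\n",
    "    # Simplified: marginal heights are taken as the maxima of the collapsed profiles\n",
    "    # (DAOFIND uses the heights of best-fitting 1D Gaussians instead).\n",
    "    hx = cutout.sum(axis=0).max()  # collapse along y -> profile along x\n",
    "    hy = cutout.sum(axis=1).max()  # collapse along x -> profile along y\n",
    "    # difference of the heights divided by their average\n",
    "    return (hx - hy) / ((hx + hy) / 2.0)\n",
    "\n",
    "\n",
    "demo_y, demo_x = np.mgrid[-10:11, -10:11]\n",
    "\n",
    "\n",
    "def demo_gauss2d(sigma_x, sigma_y):\n",
    "    # noiseless elliptical Gaussian centred in a 21 x 21 cutout\n",
    "    return np.exp(-0.5 * ((demo_x / sigma_x)**2 + (demo_y / sigma_y)**2))\n",
    "\n",
    "\n",
    "for demo_sx, demo_sy in [(2.0, 2.0), (3.0, 1.5), (1.5, 3.0)]:\n",
    "    demo_r2 = demo_roundness2(demo_gauss2d(demo_sx, demo_sy))\n",
    "    print(f'sigma_x={demo_sx}, sigma_y={demo_sy}: roundness2 = {demo_r2:.3f}')"
   ]
  },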
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cda71d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 8))\n",
    "plt.clf()\n",
    "\n",
    "ax1 = plt.subplot(2, 1, 1)\n",
    "\n",
    "ax1.set_xlabel('mag', fontdict=font2)\n",
    "ax1.set_ylabel('sharpness', fontdict=font2)\n",
    "\n",
    "xlim0 = np.min(found_stars['mag']) - 0.25\n",
    "xlim1 = np.max(found_stars['mag']) + 0.25\n",
    "ylim0 = np.min(found_stars['sharpness']) - 0.15\n",
    "ylim1 = np.max(found_stars['sharpness']) + 0.15\n",
    "\n",
    "ax1.set_xlim(xlim0, xlim1)\n",
    "ax1.set_ylim(ylim0, ylim1)\n",
    "\n",
    "ax1.xaxis.set_major_locator(ticker.AutoLocator())\n",
    "ax1.xaxis.set_minor_locator(ticker.AutoMinorLocator())\n",
    "ax1.yaxis.set_major_locator(ticker.AutoLocator())\n",
    "ax1.yaxis.set_minor_locator(ticker.AutoMinorLocator())\n",
    "\n",
    "ax1.scatter(found_stars['mag'], found_stars['sharpness'], s=10, color='k')\n",
    "\n",
    "sh_inf = 0.78\n",
    "sh_sup = 0.92\n",
    "mag_lim = -4.0\n",
    "\n",
    "ax1.plot([xlim0, xlim1], [sh_sup, sh_sup], color='r', lw=3, ls='--')\n",
    "ax1.plot([xlim0, xlim1], [sh_inf, sh_inf], color='r', lw=3, ls='--')\n",
    "ax1.plot([mag_lim, mag_lim], [ylim0, ylim1], color='r', lw=3, ls='--')\n",
    "\n",
    "ax2 = plt.subplot(2, 1, 2)\n",
    "\n",
    "ax2.set_xlabel('mag', fontdict=font2)\n",
    "ax2.set_ylabel('roundness', fontdict=font2)\n",
    "\n",
    "ylim0 = np.min(found_stars['roundness2']) - 0.25\n",
    "ylim1 = np.max(found_stars['roundness2']) + 0.25\n",
    "\n",
    "ax2.set_xlim(xlim0, xlim1)\n",
    "ax2.set_ylim(ylim0, ylim1)\n",
    "\n",
    "ax2.xaxis.set_major_locator(ticker.AutoLocator())\n",
    "ax2.xaxis.set_minor_locator(ticker.AutoMinorLocator())\n",
    "ax2.yaxis.set_major_locator(ticker.AutoLocator())\n",
    "ax2.yaxis.set_minor_locator(ticker.AutoMinorLocator())\n",
    "\n",
    "round_inf = -0.40\n",
    "round_sup = 0.40\n",
    "\n",
    "ax2.scatter(found_stars['mag'], found_stars['roundness2'], s=10, color='k')\n",
    "\n",
    "ax2.plot([xlim0, xlim1], [round_sup, round_sup], color='r', lw=3, ls='--')\n",
    "ax2.plot([xlim0, xlim1], [round_inf, round_inf], color='r', lw=3, ls='--')\n",
    "ax2.plot([mag_lim, mag_lim], [ylim0, ylim1], color='r', lw=3, ls='--')\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd866ae6",
   "metadata": {},
   "source": [
    "### 5.4<font color='white'>-</font>Create catalog of selected sources<a class=\"anchor\" id=\"create_cat\"></a> ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9dc60860",
   "metadata": {},
   "source": [
    "We can also include a separation criterion if we want to retain in the final catalog only the stars that are well isolated. In particular, we can select only the stars that do not have a neighbour closer than X pixels, where X is a parameter that can be set manually.\n",
    "\n",
    "**Note**: The magnitude limit and the minimum distance to the closest neighbour depend on the user's science case (e.g., number of stars in the field of view, crowding, number of bright sources, minimum number of stars required to build the ePSF) and must be modified accordingly."
   ]
  },
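  {
   "cell_type": "markdown",
   "id": "c41d7e24",
   "metadata": {},
   "source": [
    "The loop in the next cell computes, for each selected star, the distance to every detected source and keeps the second-smallest value (the smallest is the star itself). A vectorized alternative, sketched below assuming SciPy is installed, builds a k-d tree ([scipy.spatial.cKDTree](https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.cKDTree.html)) on all sources and queries the two nearest neighbours of each selected star. `demo_nearest_neighbour_distance` and the toy coordinates are hypothetical; like the loop, the sketch assumes the selected stars are also part of the full source list."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c41d7e25",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.spatial import cKDTree\n",
    "\n",
    "\n",
    "def demo_nearest_neighbour_distance(x_sel, y_sel, x_all, y_all):\n",
    "    # build the tree on all detected sources\n",
    "    tree = cKDTree(np.column_stack([x_all, y_all]))\n",
    "    # k=2: the first match is the star itself (distance 0),\n",
    "    # the second one is its closest neighbour\n",
    "    dist, _ = tree.query(np.column_stack([x_sel, y_sel]), k=2)\n",
    "    return dist[:, 1]\n",
    "\n",
    "\n",
    "# toy catalogue (hypothetical coordinates, in pixels)\n",
    "demo_x_all = np.array([0.0, 3.0, 30.0])\n",
    "demo_y_all = np.array([0.0, 4.0, 0.0])\n",
    "demo_sel = [0, 2]  # indices of the 'selected' stars\n",
    "\n",
    "demo_d = demo_nearest_neighbour_distance(demo_x_all[demo_sel], demo_y_all[demo_sel],\n",
    "                                         demo_x_all, demo_y_all)\n",
    "demo_isolated = demo_d > 10  # same role as min_sep in the next cell\n",
    "print(demo_d, demo_isolated)"
   ]
  },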
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17087a44",
   "metadata": {},
   "outputs": [],
   "source": [
    "mask = ((found_stars['mag'] < mag_lim) & (found_stars['roundness2'] > round_inf)\n",
    "        & (found_stars['roundness2'] < round_sup) & (found_stars['sharpness'] > sh_inf) \n",
    "        & (found_stars['sharpness'] < sh_sup))\n",
    "\n",
    "found_stars_sel = found_stars[mask]\n",
    "found_stars_sel_f115w = found_stars_sel\n",
    "\n",
    "print('Number of stars selected to build ePSF:', len(found_stars_sel))\n",
    "\n",
    "# if we include the separation criteria:\n",
    "\n",
    "d = []\n",
    "\n",
    "# we do not want any stars in a 10 px radius. \n",
    "\n",
    "min_sep = 10\n",
    "\n",
    "x_tot = found_stars['xcentroid']\n",
    "y_tot = found_stars['ycentroid']\n",
    "\n",
    "for xx, yy in zip(found_stars_sel['xcentroid'], found_stars_sel['ycentroid']):\n",
    "\n",
    "    # the smallest distance (index 0) is the star itself, so take the second one\n",
    "    dist = np.sqrt((x_tot - xx)**2 + (y_tot - yy)**2)\n",
    "    sep = np.sort(dist)[1]\n",
    "    d.append(sep)\n",
    "\n",
    "found_stars_sel['min distance'] = d\n",
    "mask_dist = (found_stars_sel['min distance'] > min_sep)\n",
    "\n",
    "found_stars_sel2 = found_stars_sel[mask_dist]\n",
    "found_stars_sel2_f115w = found_stars_sel2\n",
    "\n",
    "print('Number of stars selected to build ePSF '\n",
    "      '(including \"minimum distance to closest neighbour\" selection):', len(found_stars_sel2))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5238fb7f",
   "metadata": {},
   "source": [
    "### 5.5<font color='white'>-</font>Build the effective PSF<a class=\"anchor\" id=\"build_epsf\"></a> ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "299b3b7c",
   "metadata": {},
   "source": [
    "We build the effective PSF using the [EPSFBuilder](https://photutils.readthedocs.io/en/stable/api/photutils.psf.EPSFBuilder.html#photutils.psf.EPSFBuilder) class.\n",
    "\n",
    "First, we exclude the objects whose bounding box exceeds the detector edge. Then, we extract cutouts of the stars using the [extract_stars()](https://photutils.readthedocs.io/en/stable/api/photutils.psf.extract_stars.html#photutils.psf.extract_stars) function. The size of the cutout is set by the parameter `size` in our function *build_epsf*. Once we have the object containing the cutouts of the selected stars, we build the ePSF with the [EPSFBuilder](https://photutils.readthedocs.io/en/stable/api/photutils.psf.EPSFBuilder.html#photutils.psf.EPSFBuilder) class.\n"
   ]
  },
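  {
   "cell_type": "markdown",
   "id": "c41d7e26",
   "metadata": {},
   "source": [
    "The edge cut and the cutout extraction can be pictured with plain NumPy. The sketch below keeps only the stars whose `size`-by-`size` box lies fully inside the image (the same condition used in *build_epsf*) and slices a cutout centred on the nearest pixel. It is a simplification: [extract_stars()](https://photutils.readthedocs.io/en/stable/api/photutils.psf.extract_stars.html#photutils.psf.extract_stars) handles sub-pixel centres and returns an `EPSFStars` object, whereas the hypothetical `demo_extract_cutouts` helper returns plain arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c41d7e27",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def demo_extract_cutouts(image, x, y, size=11):\n",
    "    # half-size of the (odd) cutout box, as in build_epsf\n",
    "    hsize = (size - 1) / 2\n",
    "    ny, nx = image.shape\n",
    "    # keep only stars whose box lies entirely within the image\n",
    "    keep = (x > hsize) & (x < nx - 1 - hsize) & (y > hsize) & (y < ny - 1 - hsize)\n",
    "    h = int(hsize)\n",
    "    cutouts = []\n",
    "    for xc, yc in zip(x[keep], y[keep]):\n",
    "        # centre the box on the nearest pixel (no sub-pixel shift)\n",
    "        ix, iy = int(round(xc)), int(round(yc))\n",
    "        cutouts.append(image[iy - h:iy + h + 1, ix - h:ix + h + 1])\n",
    "    return keep, cutouts\n",
    "\n",
    "\n",
    "demo_image = np.arange(50 * 50, dtype=float).reshape(50, 50)\n",
    "demo_keep, demo_cutouts = demo_extract_cutouts(demo_image,\n",
    "                                               np.array([3.0, 25.2, 46.0]),\n",
    "                                               np.array([25.0, 25.0, 25.0]))\n",
    "print(demo_keep, [c.shape for c in demo_cutouts])"
   ]
  },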
  {
   "cell_type": "markdown",
   "id": "033be76d",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-block alert-info\">\n",
    "\n",
    "Here we limit the maximum number of iterations to 3 (to limit its run time), but in practice one should use about 10 or more iterations.\n",
    "    \n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "997de3d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_epsf(det='NRCA1', filt='F070W', size=11, found_table=None, oversample=4, iters=10):\n",
    "    \n",
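    "    # half-size of the cutout box, in pixels\n",
    "    hsize = (size - 1) / 2\n",
    "    \n",
    "    x = found_table['xcentroid']\n",
    "    y = found_table['ycentroid']\n",
    "    \n",
    "    # keep only the stars whose cutout lies entirely within the image\n",
    "    mask = ((x > hsize) & (x < (data.shape[1] - 1 - hsize)) & (y > hsize) & (y < (data.shape[0] - 1 - hsize)))\n",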
    "\n",
    "    stars_tbl = Table()\n",
    "    stars_tbl['x'] = x[mask]\n",
    "    stars_tbl['y'] = y[mask]\n",
    "    \n",
    "    data_bkgsub, _ = calc_bkg()\n",
    "    \n",
    "    nddata = NDData(data=data_bkgsub)\n",
    "    stars = extract_stars(nddata, stars_tbl, size=size)\n",
    "\n",
    "    print('Creating ePSF --- Detector {d}, filter {f}'.format(f=filt, d=det))\n",
    "\n",
    "    epsf_builder = EPSFBuilder(oversampling=oversample, maxiters=iters, progress_bar=True)\n",
    "\n",
    "    epsf, fitted_stars = epsf_builder(stars)\n",
    "    \n",
    "    return epsf"
   ]
  },
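  {
   "cell_type": "markdown",
   "id": "c41d7e28",
   "metadata": {},
   "source": [
    "For intuition about what the builder produces, the simplest empirical PSF is the pixel-wise median of normalized, centred star cutouts. The sketch below does only that; it is **not** what [EPSFBuilder](https://photutils.readthedocs.io/en/stable/api/photutils.psf.EPSFBuilder.html#photutils.psf.EPSFBuilder) does, since the builder also oversamples the PSF (`oversampling`), recentres the stars and iterates (`maxiters`). The helper `demo_median_psf` and the synthetic cutouts (three clean stars of different brightness plus one contaminated by a neighbour) are hypothetical. Thanks to the median, the single contaminated cutout does not affect the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c41d7e29",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def demo_median_psf(cutouts):\n",
    "    # normalize every cutout to unit total flux so bright and faint stars weigh equally\n",
    "    stack = np.array([c / c.sum() for c in cutouts])\n",
    "    # pixel-wise median: robust against a few contaminated cutouts\n",
    "    psf = np.median(stack, axis=0)\n",
    "    return psf / psf.sum()\n",
    "\n",
    "\n",
    "demo_y, demo_x = np.mgrid[-5:6, -5:6]\n",
    "demo_star = np.exp(-(demo_x**2 + demo_y**2) / (2 * 1.5**2))\n",
    "demo_neighbour = 0.5 * np.exp(-((demo_x - 3)**2 + demo_y**2) / (2 * 1.5**2))\n",
    "\n",
    "demo_cutouts = [100 * demo_star, 250 * demo_star, 40 * demo_star,\n",
    "                80 * (demo_star + demo_neighbour)]\n",
    "demo_psf = demo_median_psf(demo_cutouts)\n",
    "print(np.unravel_index(np.argmax(demo_psf), demo_psf.shape))"
   ]
  },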
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f473acae",
   "metadata": {},
   "outputs": [],
   "source": [
    "epsf = build_epsf(det=det, filt=filt, size=11, found_table=found_stars_sel, oversample=4, iters=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6d002e7",
   "metadata": {},
   "source": [
    "### 5.6<font color='white'>-</font>Display the effective PSF<a class=\"anchor\" id=\"display_epsf\"></a> ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "064ab4fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 12))\n",
    "\n",
    "ax = plt.subplot(1, 1, 1)\n",
    "\n",
    "norm_epsf = simple_norm(epsf.data, 'log', percent=99.)\n",
    "plt.title(filt, fontsize=30)\n",
    "ax.imshow(epsf.data, norm=norm_epsf)\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "315795e8",
   "metadata": {},
   "source": [
    "6.<font color='white'>-</font>Perform PSF Photometry<a class=\"anchor\" id=\"psf_phot\"></a>\n",
    "------------------"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73165380",
   "metadata": {},
   "source": [
    "For general information on PSF photometry with Photutils, see [here](https://photutils.readthedocs.io/en/stable/psf.html).\n",
    "\n",
    "Photutils provides three classes to perform PSF Photometry: [BasicPSFPhotometry](https://photutils.readthedocs.io/en/stable/api/photutils.psf.BasicPSFPhotometry.html#photutils.psf.BasicPSFPhotometry), [IterativelySubtractedPSFPhotometry](https://photutils.readthedocs.io/en/stable/api/photutils.psf.IterativelySubtractedPSFPhotometry.html#photutils.psf.IterativelySubtractedPSFPhotometry), and [DAOPhotPSFPhotometry](https://photutils.readthedocs.io/en/stable/api/photutils.psf.DAOPhotPSFPhotometry.html#photutils.psf.DAOPhotPSFPhotometry). Together these provide the core workflow to make photometric measurements given an appropriate PSF (or other) model.\n",
    "\n",
    "[BasicPSFPhotometry](https://photutils.readthedocs.io/en/stable/api/photutils.psf.BasicPSFPhotometry.html#photutils.psf.BasicPSFPhotometry) implements the minimum tools for model-fitting photometry. At its core, this involves finding sources in an image, grouping overlapping sources into a single model, fitting the model to the sources, and subtracting the models from the image. In DAOPHOT parlance, this is essentially running the “FIND, GROUP, NSTAR, SUBTRACT” once.\n",
    "\n",
    "[IterativelySubtractedPSFPhotometry](https://photutils.readthedocs.io/en/stable/api/photutils.psf.IterativelySubtractedPSFPhotometry.html#photutils.psf.IterativelySubtractedPSFPhotometry) (adopted here) is similar to [BasicPSFPhotometry](https://photutils.readthedocs.io/en/stable/api/photutils.psf.BasicPSFPhotometry.html#photutils.psf.BasicPSFPhotometry), but it adds a parameter called `niters`, the number of iterations for which the loop “FIND, GROUP, NSTAR, SUBTRACT, FIND…” is performed. This class enables photometry when there is significant overlap between stars of very different brightness. For instance, the detection algorithm may not detect a faint star lying very close to a bright one in the first iteration, but the faint star will be detected in the next iteration, after the brighter stars have been fit and subtracted. Like [BasicPSFPhotometry](https://photutils.readthedocs.io/en/stable/api/photutils.psf.BasicPSFPhotometry.html#photutils.psf.BasicPSFPhotometry), it does not include implementations of the individual stages of this process, but it provides the structure in which those stages run.\n",
    "\n",
    "**Important parameters**:\n",
    "\n",
    "* `finder`: class used to find stars in the image. We use [DAOStarFinder](https://photutils.readthedocs.io/en/stable/api/photutils.detection.DAOStarFinder.html).\n",
    "\n",
    "* `group_maker`: clustering algorithm used to label the sources according to groups. We use [DAOGroup](https://photutils.readthedocs.io/en/stable/api/photutils.psf.DAOGroup.html#photutils.psf.DAOGroup). Its method `group_stars` divides an entire star list into sets of distinct, self-contained groups of mutually overlapping stars. It accepts as input a list of stars and determines which stars are close enough to be capable of adversely influencing each other's profile fits. [DAOGroup](https://photutils.readthedocs.io/en/stable/api/photutils.psf.DAOGroup.html#photutils.psf.DAOGroup) accepts one parameter, `crit_separation`, the distance in pixels such that any two stars separated by less than this distance are placed in the same group.\n",
    "\n",
    "* `fitter`: algorithm to fit the sources simultaneously for each group. We use an astropy fitter, [LevMarLSQFitter](https://docs.astropy.org/en/stable/api/astropy.modeling.fitting.LevMarLSQFitter.html#astropy.modeling.fitting.LevMarLSQFitter). \n",
    "\n",
    "* `niters`: number of iterations for which the \"psf photometry\" loop described above is performed.\n",
    "\n",
    "* `fitshape`: Rectangular shape around the center of a star which will be used to collect the data to do the fitting. \n",
    "\n",
    "* `aperture_radius`: The radius (in units of pixels) used to compute initial estimates for the fluxes of sources."
   ]
  },
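  {
   "cell_type": "markdown",
   "id": "c41d7e2a",
   "metadata": {},
   "source": [
    "The FIND, GROUP, NSTAR, SUBTRACT loop can be illustrated with a toy version written with NumPy and SciPy only. The sketch below assumes a known, circular Gaussian PSF, keeps the star positions fixed at the detected pixel, and at each iteration refits all sources found so far simultaneously (as if they formed a single group) with a linear least-squares solve; none of this reproduces the internals of [IterativelySubtractedPSFPhotometry](https://photutils.readthedocs.io/en/stable/api/photutils.psf.IterativelySubtractedPSFPhotometry.html#photutils.psf.IterativelySubtractedPSFPhotometry). All `demo_*` names and the noiseless synthetic image are hypothetical. In the example, a faint star 4 px away from a bright one is not a local maximum in the original image, so it is only found in the second iteration, after the bright star has been subtracted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c41d7e2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy import ndimage\n",
    "\n",
    "DEMO_SIGMA = 1.5  # assumed width (pixels) of the Gaussian PSF\n",
    "demo_yy, demo_xx = np.mgrid[0:64, 0:64]\n",
    "\n",
    "\n",
    "def demo_psf_image(x0, y0):\n",
    "    # unit-amplitude Gaussian PSF centred on (x0, y0)\n",
    "    return np.exp(-((demo_xx - x0)**2 + (demo_yy - y0)**2) / (2 * DEMO_SIGMA**2))\n",
    "\n",
    "\n",
    "def demo_find(img, threshold):\n",
    "    # FIND: local maxima (3 x 3 neighbourhood) above the threshold\n",
    "    peaks = (img == ndimage.maximum_filter(img, size=3)) & (img > threshold)\n",
    "    ys, xs = np.nonzero(peaks)\n",
    "    return list(zip(xs.tolist(), ys.tolist()))\n",
    "\n",
    "\n",
    "def demo_fit(img, positions):\n",
    "    # FIT: amplitudes of all sources at once, positions held fixed\n",
    "    design = np.column_stack([demo_psf_image(x, y).ravel() for x, y in positions])\n",
    "    amps, *_ = np.linalg.lstsq(design, img.ravel(), rcond=None)\n",
    "    return amps\n",
    "\n",
    "\n",
    "def demo_iterative_phot(img, threshold, niters):\n",
    "    positions, amps, residual = [], np.array([]), img.copy()\n",
    "    for it in range(niters):\n",
    "        new = demo_find(residual, threshold)\n",
    "        if not new:\n",
    "            break\n",
    "        positions += new\n",
    "        amps = demo_fit(img, positions)\n",
    "        # SUBTRACT: residual = image - model of every source found so far\n",
    "        model = sum(a * demo_psf_image(x, y) for a, (x, y) in zip(amps, positions))\n",
    "        residual = img - model\n",
    "        print(f'iteration {it + 1}: {len(new)} new source(s), {len(positions)} in total')\n",
    "    return positions, amps, residual\n",
    "\n",
    "\n",
    "# bright star at (30, 30) and a faint companion at (34, 30), no noise\n",
    "demo_image = 1000 * demo_psf_image(30, 30) + 50 * demo_psf_image(34, 30)\n",
    "demo_pos, demo_amps, demo_res = demo_iterative_phot(demo_image, threshold=10, niters=2)\n",
    "print(demo_pos, np.round(demo_amps, 3))"
   ]
  },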
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d2c4c4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def psf_phot(data=None, det='NRCA1', filt='F070W', th=2000, psf=None, ap_radius=3.5, save_residuals=False, \n",
    "             save_output=False):\n",
    "\n",
    "    fitter = LevMarLSQFitter()\n",
    "    mmm_bkg = MMMBackground()\n",
    "        \n",
    "    sigma_psf = dict_utils[filt]['psf fwhm']\n",
    "    print('FWHM for filter {f}:'.format(f=filt), sigma_psf)\n",
    "    \n",
    "    _, std = calc_bkg()\n",
    "    \n",
    "    daofind = DAOStarFinder(threshold=th * std, fwhm=sigma_psf)\n",
    "    \n",
    "    daogroup = DAOGroup(5.0 * sigma_psf)\n",
    "    \n",
    "    psf_model = psf.copy()\n",
    "    \n",
    "    print('Performing the PSF photometry --- Detector {d}, filter {f}'.format(f=filt, d=det))\n",
    "            \n",
    "    tic = time.perf_counter()\n",
    "    \n",
    "    phot = IterativelySubtractedPSFPhotometry(finder=daofind, group_maker=daogroup,\n",
    "                                              bkg_estimator=mmm_bkg, psf_model=psf_model,\n",
    "                                              fitter=fitter,\n",
    "                                              niters=2, fitshape=(11, 11), aperture_radius=ap_radius, \n",
    "                                              extra_output_cols=('sharpness', 'roundness2'))\n",
    "    result = phot(data)\n",
    "    \n",
    "    toc = time.perf_counter()\n",
    "    \n",
    "    print('Time needed to perform photometry:', '%.2f' % ((toc - tic) / 3600), 'hours')\n",
    "    print('Number of sources detected:', len(result))\n",
    "        \n",
    "    residual_image = phot.get_residual_image()\n",
    "    \n",
    "    # save the residual images as fits file:\n",
    "\n",
    "    if save_residuals:\n",
    "        hdu = fits.PrimaryHDU(residual_image)\n",
    "        hdul = fits.HDUList([hdu])\n",
    "    \n",
    "        residual_outname = 'residual_%s_%s.fits' % (det, filt)\n",
    "\n",
    "        hdul.writeto(os.path.join(res_dir, residual_outname))\n",
    "\n",
    "    # save the output photometry Tables\n",
    "\n",
    "    if save_output:\n",
    "\n",
    "        outname = 'phot_%s_%s.pkl' % (det, filt)\n",
    "        \n",
    "        tab = result.to_pandas()\n",
    "        tab.to_pickle(os.path.join(output_phot_dir, outname))\n",
    "    \n",
    "    return result, residual_image"
   ]
  },
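  {
   "cell_type": "markdown",
   "id": "c41d7e2c",
   "metadata": {},
   "source": [
    "In *psf_phot*, [DAOGroup](https://photutils.readthedocs.io/en/stable/api/photutils.psf.DAOGroup.html#photutils.psf.DAOGroup) is created with `crit_separation = 5.0 * sigma_psf`, where `sigma_psf` holds the PSF FWHM in pixels. The grouping criterion described above (stars closer than `crit_separation` share a group, and groups connected through chains of close pairs are merged) can be sketched as finding the connected components of a neighbour graph. The sketch below assumes SciPy (`cKDTree.query_pairs` and `scipy.sparse.csgraph.connected_components`) and is a hypothetical illustration of the criterion, not the Photutils implementation; note also that `query_pairs` includes pairs separated by exactly `crit_separation`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c41d7e2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.sparse import coo_matrix\n",
    "from scipy.sparse.csgraph import connected_components\n",
    "from scipy.spatial import cKDTree\n",
    "\n",
    "\n",
    "def demo_group_stars(x, y, crit_separation):\n",
    "    xy = np.column_stack([x, y])\n",
    "    n = len(xy)\n",
    "    # all pairs of stars closer than (or exactly at) crit_separation\n",
    "    pairs = cKDTree(xy).query_pairs(r=crit_separation, output_type='ndarray')\n",
    "    # sparse adjacency matrix of the neighbour graph\n",
    "    adjacency = coo_matrix((np.ones(len(pairs)), (pairs[:, 0], pairs[:, 1])), shape=(n, n))\n",
    "    # stars connected through any chain of neighbours share a group\n",
    "    _, labels = connected_components(adjacency, directed=False)\n",
    "    return labels + 1  # group ids starting from 1\n",
    "\n",
    "\n",
    "# three stars in a chain (0-1-2) and one isolated star\n",
    "demo_gid = demo_group_stars(np.array([0.0, 3.0, 6.0, 50.0]), np.zeros(4), crit_separation=4.0)\n",
    "print(demo_gid)"
   ]
  },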
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "934742c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# During the Webbinar we use a cutout of the whole image to speed up the reduction process\n",
    "\n",
    "data1 = data[0:500, 0:500]\n",
    "# data1 = data\n",
    "\n",
    "output_phot_dir = 'PHOT_OUTPUT/'\n",
    "\n",
    "if not os.path.exists(output_phot_dir):\n",
    "    os.makedirs(output_phot_dir)\n",
    "\n",
    "res_dir = 'RESIDUAL_IMAGES/'\n",
    "\n",
    "if not os.path.exists(res_dir):\n",
    "    os.makedirs(res_dir)\n",
    "\n",
    "if glob.glob(os.path.join(res_dir, 'residual*F115W.fits')):\n",
    "    print('Deleting Residual images from directory')\n",
    "    files = glob.glob(os.path.join(res_dir, 'residual*F115W.fits'))\n",
    "    for file in files:\n",
    "        os.remove(file)\n",
    "\n",
    "psf_phot_results, residual_image = psf_phot(data=data1, det=det, filt=filt, th=10, psf=psf_webbpsf_grid, \n",
    "                                            save_residuals=True, save_output=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fcd71de",
   "metadata": {},
   "source": [
    "### 6.1<font color='white'>-</font>PSF photometry output catalog<a class=\"anchor\" id=\"psf_cat\"></a> ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "861c03d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "psf_phot_results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6feb47a5",
   "metadata": {},
   "source": [
    "### 6.2<font color='white'>-</font>Display residual image<a class=\"anchor\" id=\"residual\"></a> ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eb3b045",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(14, 14))\n",
    "\n",
    "ax1 = plt.subplot(2, 2, 1)\n",
    "\n",
    "plt.xlabel(\"X [px]\", fontdict=font2)\n",
    "plt.ylabel(\"Y [px]\", fontdict=font2)\n",
    "plt.title(filt, fontdict=font2)\n",
    "\n",
    "norm = simple_norm(data1, 'sqrt', percent=99.)\n",
    "ax1.imshow(data1, norm=norm, cmap='Greys')\n",
    "\n",
    "ax2 = plt.subplot(2, 2, 2)\n",
    "\n",
    "plt.xlabel(\"X [px]\", fontdict=font2)\n",
    "plt.ylabel(\"Y [px]\", fontdict=font2)\n",
    "plt.title('residuals', fontdict=font2)\n",
    "\n",
    "norm = simple_norm(data1, 'sqrt', percent=99.)\n",
    "ax2.imshow(residual_image, norm=norm, cmap='Greys')\n",
    "\n",
    "ax3 = plt.subplot(2, 2, 3)\n",
    "\n",
    "plt.xlabel(\"X [px]\", fontdict=font2)\n",
    "plt.ylabel(\"Y [px]\", fontdict=font2)\n",
    "plt.title(filt, fontdict=font2)\n",
    "\n",
    "norm = simple_norm(data, 'sqrt', percent=99.)\n",
    "ax3.imshow(data, norm=norm, cmap='Greys')\n",
    "\n",
    "ax4 = plt.subplot(2, 2, 4)\n",
    "\n",
    "if os.path.isfile('./residual_webbpsf_grid16_NRCB1_F115W.fits'):\n",
    "    res_f115w = './residual_webbpsf_grid16_NRCB1_F115W.fits'\n",
    "\n",
    "else:\n",
    "    print('Downloading F115W residual image')\n",
    "    \n",
    "    boxlink_res_f115w = 'https://stsci.box.com/shared/static/g4ffi7zowwlj91up4nkqxz38c1l4tpjx.fits'\n",
    "    boxfile_res_f115w = './residual_webbpsf_grid16_NRCB1_F115W.fits'\n",
    "    urllib.request.urlretrieve(boxlink_res_f115w, boxfile_res_f115w)\n",
    "    res_f115w = boxfile_res_f115w\n",
    "\n",
    "if os.path.isfile('./residual_webbpsf_grid16_NRCB1_F200W.fits'):\n",
    "    res_f200w = './residual_webbpsf_grid16_NRCB1_F200W.fits'\n",
    "\n",
    "else:\n",
    "    print('Downloading F200W residual image')\n",
    "    \n",
    "    boxlink_res_f200w = 'https://stsci.box.com/shared/static/mssn25cokiwfwco9f289nds7lennfgvv.fits'\n",
    "    boxfile_res_f200w = './residual_webbpsf_grid16_NRCB1_F200W.fits'\n",
    "    urllib.request.urlretrieve(boxlink_res_f200w, boxfile_res_f200w)\n",
    "    res_f200w = boxfile_res_f200w\n",
    "\n",
    "residual = res_f115w\n",
    "\n",
    "residual = fits.open(residual)\n",
    "res_data = residual[0].data\n",
    "\n",
    "plt.xlabel(\"X [px]\", fontdict=font2)\n",
    "plt.ylabel(\"Y [px]\", fontdict=font2)\n",
    "plt.title('residuals', fontdict=font2)\n",
    "\n",
    "ax4.imshow(res_data, norm=norm, cmap='Greys')\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad4208d4",
   "metadata": {},
   "source": [
    "7.<font color='white'>-</font>Exercise: perform PSF photometry on the other filter<a class=\"anchor\" id=\"exercise\"></a>\n",
    "------------------"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1425cab5",
   "metadata": {},
   "outputs": [],
   "source": [
    "det = 'NRCB1'\n",
    "filt = 'F200W'\n",
    "\n",
    "im = fits.open(dict_images[det][filt]['images'][0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "837b064f",
   "metadata": {},
   "source": [
    "Display image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0c24ce0",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 12))\n",
    "   \n",
    "data_sb = im[1].data\n",
    "\n",
    "ax = plt.subplot(1, 1, 1)\n",
    "\n",
    "plt.xlabel(\"X [px]\", fontdict=font2)\n",
    "plt.ylabel(\"Y [px]\", fontdict=font2)\n",
    "plt.title(filt, fontdict=font2)\n",
    "norm = simple_norm(data_sb, 'sqrt', percent=99.)\n",
    "\n",
    "ax.imshow(data_sb, norm=norm, cmap='Greys')\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "985d35ee",
   "metadata": {},
   "source": [
    "Convert the image to DN/s and apply the pixel area map (PAM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f9e0b15",
   "metadata": {},
   "outputs": [],
   "source": [
    "imh = im[1].header\n",
    "data = data_sb / imh['PHOTMJSR']\n",
    "print('Conversion factor from {units} to DN/s for filter {f}:'.format(units=imh['BUNIT'], f=filt), imh['PHOTMJSR'])\n",
    "area = im[4].data\n",
    "data = data * area"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95c3054b",
   "metadata": {},
   "source": [
    "Create a synthetic PSF with WebbPSF (single or grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45b5dea5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# single:\n",
    "\n",
    "psf_webbpsf_single = create_psf_model(det=det, filt=filt, fov=11, create_grid=False, save_psf=True, \n",
    "                                      detsampled=False)\n",
    "\n",
    "# grid\n",
    "\n",
    "psf_webbpsf_grid = create_psf_model(det=det, filt=filt, fov=11, create_grid=True, num=16, save_psf=True, \n",
    "                                    detsampled=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26d4dd66",
   "metadata": {},
   "source": [
    "Display the synthetic PSF (single)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c29ab0cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 12))\n",
    "\n",
    "ax = plt.subplot(1, 1, 1)\n",
    "\n",
    "norm_psf = simple_norm(psf_webbpsf_single.data[0], 'log', percent=99.)\n",
    "ax.set_title(filt, fontsize=40)\n",
    "ax.imshow(psf_webbpsf_single.data[0], norm=norm_psf)\n",
    "ax.set_xlabel('X [px]', fontsize=30)\n",
    "ax.set_ylabel('Y [px]', fontsize=30)\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "205105d6",
   "metadata": {},
   "source": [
    "Display the synthetic PSF (grid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a213dd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "webbpsf.gridded_library.display_psf_grid(psf_webbpsf_grid, zoom_in=False, figsize=(14, 14))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bcf8e0bb",
   "metadata": {},
   "source": [
    "Build an effective PSF\n",
    " - Find stars\n",
    " - Select sources using DAOStarFinder diagnostics\n",
    " - Create catalog for selected sources\n",
    " - Build effective PSF\n",
    " - Display the effective PSF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fcc902f",
   "metadata": {},
   "outputs": [],
   "source": [
    "tic = time.perf_counter()\n",
    "\n",
    "found_stars = find_stars(det=det, filt=filt, threshold=10, var_bkg=False)\n",
    "\n",
    "toc = time.perf_counter()\n",
    "\n",
    "print(\"Elapsed Time for finding stars:\", toc - tic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "211463ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 8))\n",
    "plt.clf()\n",
    "\n",
    "ax1 = plt.subplot(2, 1, 1)\n",
    "\n",
    "ax1.set_xlabel('mag', fontdict=font2)\n",
    "ax1.set_ylabel('sharpness', fontdict=font2)\n",
    "\n",
    "xlim0 = np.min(found_stars['mag']) - 0.25\n",
    "xlim1 = np.max(found_stars['mag']) + 0.25\n",
    "ylim0 = np.min(found_stars['sharpness']) - 0.15\n",
    "ylim1 = np.max(found_stars['sharpness']) + 0.15\n",
    "\n",
    "ax1.set_xlim(xlim0, xlim1)\n",
    "ax1.set_ylim(ylim0, ylim1)\n",
    "\n",
    "ax1.xaxis.set_major_locator(ticker.AutoLocator())\n",
    "ax1.xaxis.set_minor_locator(ticker.AutoMinorLocator())\n",
    "ax1.yaxis.set_major_locator(ticker.AutoLocator())\n",
    "ax1.yaxis.set_minor_locator(ticker.AutoMinorLocator())\n",
    "\n",
    "ax1.scatter(found_stars['mag'], found_stars['sharpness'], s=10, color='k')\n",
    "\n",
    "sh_inf = 0.57\n",
    "sh_sup = 0.70\n",
    "mag_lim = -3.0\n",
    "\n",
    "ax1.plot([xlim0, xlim1], [sh_sup, sh_sup], color='r', lw=3, ls='--')\n",
    "ax1.plot([xlim0, xlim1], [sh_inf, sh_inf], color='r', lw=3, ls='--')\n",
    "ax1.plot([mag_lim, mag_lim], [ylim0, ylim1], color='r', lw=3, ls='--')\n",
    "\n",
    "ax2 = plt.subplot(2, 1, 2)\n",
    "\n",
    "ax2.set_xlabel('mag', fontdict=font2)\n",
    "ax2.set_ylabel('roundness', fontdict=font2)\n",
    "\n",
    "ylim0 = np.min(found_stars['roundness2']) - 0.25\n",
    "ylim1 = np.max(found_stars['roundness2']) + 0.25\n",
    "\n",
    "ax2.set_xlim(xlim0, xlim1)\n",
    "ax2.set_ylim(ylim0, ylim1)\n",
    "\n",
    "ax2.xaxis.set_major_locator(ticker.AutoLocator())\n",
    "ax2.xaxis.set_minor_locator(ticker.AutoMinorLocator())\n",
    "ax2.yaxis.set_major_locator(ticker.AutoLocator())\n",
    "ax2.yaxis.set_minor_locator(ticker.AutoMinorLocator())\n",
    "\n",
    "round_inf = -0.25\n",
    "round_sup = 0.25\n",
    "\n",
    "ax2.scatter(found_stars['mag'], found_stars['roundness2'], s=10, color='k')\n",
    "\n",
    "ax2.plot([xlim0, xlim1], [round_sup, round_sup], color='r', lw=3, ls='--')\n",
    "ax2.plot([xlim0, xlim1], [round_inf, round_inf], color='r', lw=3, ls='--')\n",
    "ax2.plot([mag_lim, mag_lim], [ylim0, ylim1], color='r', lw=3, ls='--')\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60374d6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "mask = ((found_stars['mag'] < mag_lim) & (found_stars['roundness2'] > round_inf)\n",
    "        & (found_stars['roundness2'] < round_sup) & (found_stars['sharpness'] > sh_inf) \n",
    "        & (found_stars['sharpness'] < sh_sup))\n",
    "\n",
    "found_stars_sel = found_stars[mask]\n",
    "found_stars_sel_f200w = found_stars_sel\n",
    "\n",
    "print('Number of stars selected to build ePSF:', len(found_stars_sel))\n",
    "\n",
    "# if we include the separation criteria:\n",
    "\n",
    "d = []\n",
    "\n",
    "# we do not want any stars in a 10 px radius. \n",
    "\n",
    "min_sep = 10\n",
    "\n",
    "x_tot = found_stars['xcentroid']\n",
    "y_tot = found_stars['ycentroid']\n",
    "\n",
    "for xx, yy in zip(found_stars_sel['xcentroid'], found_stars_sel['ycentroid']):\n",
    "\n",
    "    # the smallest distance (index 0) is the star itself, so take the second one\n",
    "    dist = np.sqrt((x_tot - xx)**2 + (y_tot - yy)**2)\n",
    "    sep = np.sort(dist)[1]\n",
    "    d.append(sep)\n",
    "\n",
    "found_stars_sel['min distance'] = d\n",
    "mask_dist = (found_stars_sel['min distance'] > min_sep)\n",
    "\n",
    "found_stars_sel2 = found_stars_sel[mask_dist]\n",
    "found_stars_sel2_f200w = found_stars_sel2\n",
    "\n",
    "print('Number of stars selected to build ePSF '\n",
    "      '(including \"minimum distance to closest neighbour\" selection):', len(found_stars_sel2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32665568",
   "metadata": {},
   "outputs": [],
   "source": [
    "epsf = build_epsf(det=det, filt=filt, size=11, found_table=found_stars_sel, oversample=4, iters=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdf6108c",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 12))\n",
    "\n",
    "ax = plt.subplot(1, 1, 1)\n",
    "\n",
    "norm_epsf = simple_norm(epsf.data, 'log', percent=99.)\n",
    "plt.title(filt, fontsize=30)\n",
    "ax.imshow(epsf.data, norm=norm_epsf)\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f6f5be9",
   "metadata": {},
   "source": [
    "Perform PSF photometry \n",
    "\n",
    "**Note**: remember to create a cutout of the original image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bf393ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "data1 = data[0:500, 0:500]\n",
    "# data1 = data\n",
    "\n",
    "output_phot_dir = 'PHOT_OUTPUT/'\n",
    "\n",
    "if not os.path.exists(output_phot_dir):\n",
    "    os.makedirs(output_phot_dir)\n",
    "\n",
    "res_dir = 'RESIDUAL_IMAGES/'\n",
    "\n",
    "if not os.path.exists(res_dir):\n",
    "    os.makedirs(res_dir)\n",
    "\n",
    "if glob.glob(os.path.join(res_dir, 'residual*F200W.fits')):\n",
    "    print('Deleting Residual images from directory')\n",
    "    files = glob.glob(os.path.join(res_dir, 'residual*F200W.fits'))\n",
    "    for file in files:\n",
    "        os.remove(file)\n",
    "\n",
    "psf_phot_results, residual_image = psf_phot(data=data1, det=det, filt=filt, th=10, psf=epsf, \n",
    "                                            save_residuals=True, save_output=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f56d91d6",
   "metadata": {},
   "source": [
    "Display the residual image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66cded6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(14, 14))\n",
    "\n",
    "ax1 = plt.subplot(2, 2, 1)\n",
    "\n",
    "plt.xlabel(\"X [px]\", fontdict=font2)\n",
    "plt.ylabel(\"Y [px]\", fontdict=font2)\n",
    "plt.title(filt, fontdict=font2)\n",
    "\n",
    "norm = simple_norm(data1, 'sqrt', percent=99.)\n",
    "ax1.imshow(data1, norm=norm, cmap='Greys')\n",
    "\n",
    "ax2 = plt.subplot(2, 2, 2)\n",
    "\n",
    "plt.xlabel(\"X [px]\", fontdict=font2)\n",
    "plt.ylabel(\"Y [px]\", fontdict=font2)\n",
    "plt.title('residuals', fontdict=font2)\n",
    "\n",
    "norm = simple_norm(data1, 'sqrt', percent=99.)\n",
    "ax2.imshow(residual_image, norm=norm, cmap='Greys')\n",
    "\n",
    "ax3 = plt.subplot(2, 2, 3)\n",
    "\n",
    "plt.xlabel(\"X [px]\", fontdict=font2)\n",
    "plt.ylabel(\"Y [px]\", fontdict=font2)\n",
    "plt.title(filt, fontdict=font2)\n",
    "\n",
    "norm = simple_norm(data, 'sqrt', percent=99.)\n",
    "ax3.imshow(data, norm=norm, cmap='Greys')\n",
    "\n",
    "ax4 = plt.subplot(2, 2, 4)\n",
    "\n",
    "if os.path.isfile('./residual_webbpsf_grid16_NRCB1_F115W.fits'):\n",
    "    res_f115w = './residual_webbpsf_grid16_NRCB1_F115W.fits'\n",
    "\n",
    "else:\n",
    "    print('Downloading F115W residual image')\n",
    "    \n",
    "    boxlink_res_f115w = 'https://stsci.box.com/shared/static/g4ffi7zowwlj91up4nkqxz38c1l4tpjx.fits'\n",
    "    boxfile_res_f115w = './residual_webbpsf_grid16_NRCB1_F115W.fits'\n",
    "    urllib.request.urlretrieve(boxlink_res_f115w, boxfile_res_f115w)\n",
    "    res_f115w = boxfile_res_f115w\n",
    "\n",
    "if os.path.isfile('./residual_webbpsf_grid16_NRCB1_F200W.fits'):\n",
    "    res_f200w = './residual_webbpsf_grid16_NRCB1_F200W.fits'\n",
    "\n",
    "else:\n",
    "    print('Downloading F200W residual image')\n",
    "    \n",
    "    boxlink_res_f200w = 'https://stsci.box.com/shared/static/mssn25cokiwfwco9f289nds7lennfgvv.fits'\n",
    "    boxfile_res_f200w = './residual_webbpsf_grid16_NRCB1_F200W.fits'\n",
    "    urllib.request.urlretrieve(boxlink_res_f200w, boxfile_res_f200w)\n",
    "    res_f200w = boxfile_res_f200w\n",
    "\n",
    "# show the precomputed full-image residual for the filter being analyzed\n",
    "residual = res_f200w if filt == 'F200W' else res_f115w\n",
    "\n",
    "res_data = fits.getdata(residual, ext=0)\n",
    "\n",
    "plt.xlabel(\"X [px]\", fontdict=font2)\n",
    "plt.ylabel(\"Y [px]\", fontdict=font2)\n",
    "plt.title('residuals', fontdict=font2)\n",
    "\n",
    "ax4.imshow(res_data, norm=norm, cmap='Greys')\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d2e034c",
   "metadata": {},
   "source": [
    "8.<font color='white'>-</font>Bonus part I: create your first NIRCam Color-Magnitude Diagram<a class=\"anchor\" id=\"bonusI\"></a>\n",
    "------------------"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58259625",
   "metadata": {},
   "source": [
    "### 8.1<font color='white'>-</font>Load images and output catalogs<a class=\"anchor\" id=\"load_data\"></a> ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5440b51",
   "metadata": {},
   "outputs": [],
   "source": [
    "if os.path.isfile('./phot_webbpsf_grid16_NRCB1_F115W.pkl'):\n",
    "    phot_f115w = './phot_webbpsf_grid16_NRCB1_F115W.pkl'\n",
    "\n",
    "else:\n",
    "    print('Downloading F115W PSF photometry')\n",
    "    \n",
    "    boxlink_ph_f115w = 'https://stsci.box.com/shared/static/f4s4ziwh7tb3827g0ac362swwta28lti.pkl'\n",
    "    boxfile_ph_f115w = './phot_webbpsf_grid16_NRCB1_F115W.pkl'\n",
    "    urllib.request.urlretrieve(boxlink_ph_f115w, boxfile_ph_f115w)\n",
    "    phot_f115w = boxfile_ph_f115w\n",
    "\n",
    "if os.path.isfile('./phot_webbpsf_grid16_NRCB1_F200W.pkl'):\n",
    "    phot_f200w = './phot_webbpsf_grid16_NRCB1_F200W.pkl'\n",
    "\n",
    "else:\n",
    "    print('Downloading F200W PSF Photometry')\n",
    "    \n",
    "    boxlink_ph_f200w = 'https://stsci.box.com/shared/static/983dxnxqz594ogn00e6m7ek7vq22gifr.pkl'\n",
    "    boxfile_ph_f200w = './phot_webbpsf_grid16_NRCB1_F200W.pkl'\n",
    "    urllib.request.urlretrieve(boxlink_ph_f200w, boxfile_ph_f200w)\n",
    "    phot_f200w = boxfile_ph_f200w\n",
    "\n",
    "\n",
    "ph_f115w = pd.read_pickle(phot_f115w)\n",
    "ph_f200w = pd.read_pickle(phot_f200w)\n",
    "\n",
    "results_f115w = QTable.from_pandas(ph_f115w)\n",
    "results_f200w = QTable.from_pandas(ph_f200w)\n",
    "\n",
    "image_f115w = ImageModel('./jw00042001001_01101_00001_nrcb1_cal.fits')\n",
    "image_f200w = ImageModel('./jw00042001001_01101_00005_nrcb1_cal.fits')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72643e44",
   "metadata": {},
   "source": [
    "### 8.2<font color='white'>-</font>Cross-match PSF photometry catalogs<a class=\"anchor\" id=\"cross_match\"></a> ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "003f1010",
   "metadata": {},
   "source": [
    "We keep only sources with a positive fitted flux and a fitted position inside the detector (0 < x, y < 2048 px), and we use the image WCS to transform the detector coordinates (x, y) into celestial coordinates (RA, Dec).\n",
    "\n",
    "We use the [SkyCoord](https://docs.astropy.org/en/stable/api/astropy.coordinates.SkyCoord.html) class and the [match_coordinates_sky](https://docs.astropy.org/en/stable/api/astropy.coordinates.match_coordinates_sky.html) function to find, for each source in the first catalog, the nearest on-sky match in the second catalog.\n",
    "\n",
    "We consider a star to be the same in both catalogs if the separation is smaller than `max_sep`, where `max_sep` = 0.015 arcsec (i.e., about 0.5 px)."
   ]
  },
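  {
   "cell_type": "markdown",
   "id": "7c1e2a90",
   "metadata": {},
   "source": [
    "The matching step can be illustrated on a few made-up coordinates. The RA/Dec values below are hypothetical and are not taken from the NIRCam catalogs. For every source in the first catalog, `match_coordinates_sky` returns the index of its nearest neighbor in the second catalog and the on-sky separation. Keeping only separations below `max_sep` rejects sources without a real counterpart. The pixel scale of about 0.031 arcsec/px is an assumed, approximate value for the NIRCam short-wavelength detectors; it is used here only to show why 0.015 arcsec corresponds to roughly 0.5 px. All variable names ending in `_demo` are hypothetical, so they do not overwrite the notebook variables.\n",
    "\n",
    "In this example, the nearest neighbor of the third source is the star already matched to the second source. The nearest-neighbor search is therefore not one-to-one, and only the separation cut prevents a spurious pair."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a6e9c03",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import astropy.units as u\n",
    "from astropy.coordinates import SkyCoord, match_coordinates_sky\n",
    "\n",
    "# Hypothetical catalogs (degrees): the first two sources of cat2 are the same\n",
    "# stars as in cat1, shifted by 0.005 arcsec in Dec; the third one is unrelated\n",
    "dec_shift_demo = 0.005 / 3600.0\n",
    "cat1_demo = SkyCoord(ra=[10.0, 10.001, 10.002], dec=[-20.0, -20.0, -20.0], unit='deg')\n",
    "cat2_demo = SkyCoord(ra=[10.0, 10.001, 10.5],\n",
    "                     dec=[-20.0 + dec_shift_demo, -20.0 + dec_shift_demo, -20.0], unit='deg')\n",
    "\n",
    "# for each source in cat1: index of the nearest source in cat2 and the separation\n",
    "idx_demo, d2d_demo, _ = match_coordinates_sky(cat1_demo, cat2_demo)\n",
    "\n",
    "# keep only pairs closer than max_sep\n",
    "max_sep_demo = 0.015 * u.arcsec\n",
    "sep_constraint_demo = d2d_demo < max_sep_demo\n",
    "matched_idx_demo = idx_demo[sep_constraint_demo]\n",
    "\n",
    "# assumed (approximate) NIRCam short-wavelength pixel scale\n",
    "pixel_scale_demo = 0.031 * u.arcsec\n",
    "print('max_sep in pixels: {:.2f}'.format((max_sep_demo / pixel_scale_demo).decompose().value))\n",
    "print('cat1 sources with a match:', np.flatnonzero(sep_constraint_demo))\n",
    "print('matching cat2 indices:', matched_idx_demo)\n",
    "print('nearest cat2 index for every cat1 source:', idx_demo)"
   ]
  },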
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68410f33",
   "metadata": {},
   "outputs": [],
   "source": [
    "mask_f115w = ((results_f115w['x_fit'] > 0) & (results_f115w['x_fit'] < 2048) & \n",
    "              (results_f115w['y_fit'] > 0) & (results_f115w['y_fit'] < 2048) & \n",
    "              (results_f115w['flux_fit'] > 0))\n",
    "\n",
    "result_clean_f115w = results_f115w[mask_f115w]\n",
    "\n",
    "ra_f115w, dec_f115w = image_f115w.meta.wcs(result_clean_f115w['x_fit'], result_clean_f115w['y_fit'])\n",
    "radec_f115w = SkyCoord(ra_f115w, dec_f115w, unit='deg')\n",
    "result_clean_f115w['radec'] = radec_f115w\n",
    "\n",
    "mask_f200w = ((results_f200w['x_fit'] > 0) & (results_f200w['x_fit'] < 2048) & \n",
    "              (results_f200w['y_fit'] > 0) & (results_f200w['y_fit'] < 2048) & \n",
    "              (results_f200w['flux_fit'] > 0))\n",
    "\n",
    "result_clean_f200w = results_f200w[mask_f200w]\n",
    "\n",
    "ra_f200w, dec_f200w = image_f200w.meta.wcs(result_clean_f200w['x_fit'], result_clean_f200w['y_fit'])\n",
    "radec_f200w = SkyCoord(ra_f200w, dec_f200w, unit='deg')\n",
    "\n",
    "result_clean_f200w['radec'] = radec_f200w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3472fe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_sep = 0.015 * u.arcsec\n",
    "\n",
    "filt1 = 'F115W'\n",
    "filt2 = 'F200W'\n",
    "\n",
    "res1 = result_clean_f115w\n",
    "res2 = result_clean_f200w\n",
    "\n",
    "idx, d2d, _ = match_coordinates_sky(res1['radec'], res2['radec'])\n",
    "\n",
    "sep_constraint = d2d < max_sep\n",
    "\n",
    "match_phot_single = Table()\n",
    "\n",
    "x_0_f115w = res1['x_0'][sep_constraint]\n",
    "y_0_f115w = res1['y_0'][sep_constraint]\n",
    "x_fit_f115w = res1['x_fit'][sep_constraint]\n",
    "y_fit_f115w = res1['y_fit'][sep_constraint]\n",
    "radec_f115w = res1['radec'][sep_constraint]\n",
    "mag_f115w = (-2.5 * np.log10(res1['flux_fit']))[sep_constraint]\n",
    "emag_f115w = (1.086 * (res1['flux_unc'] / res1['flux_fit']))[sep_constraint]\n",
    "\n",
    "x_0_f200w = res2['x_0'][idx[sep_constraint]]\n",
    "y_0_f200w = res2['y_0'][idx[sep_constraint]]\n",
    "x_fit_f200w = res2['x_fit'][idx[sep_constraint]]\n",
    "y_fit_f200w = res2['y_fit'][idx[sep_constraint]]\n",
    "radec_f200w = res2['radec'][idx[sep_constraint]]\n",
    "mag_f200w = (-2.5 * np.log10(res2['flux_fit']))[idx[sep_constraint]]\n",
    "emag_f200w = (1.086 * (res2['flux_unc'] / res2['flux_fit']))[idx[sep_constraint]]\n",
    "\n",
    "match_phot_single['x_0_' + filt1] = x_0_f115w\n",
    "match_phot_single['y_0_' + filt1] = y_0_f115w\n",
    "match_phot_single['x_fit_' + filt1] = x_fit_f115w\n",
    "match_phot_single['y_fit_' + filt1] = y_fit_f115w\n",
    "match_phot_single['radec_' + filt1] = radec_f115w\n",
    "match_phot_single['mag_' + filt1] = mag_f115w\n",
    "match_phot_single['emag_' + filt1] = emag_f115w\n",
    "match_phot_single['x_0_' + filt2] = x_0_f200w\n",
    "match_phot_single['y_0_' + filt2] = y_0_f200w\n",
    "match_phot_single['x_fit_' + filt2] = x_fit_f200w\n",
    "match_phot_single['y_fit_' + filt2] = y_fit_f200w\n",
    "match_phot_single['radec_' + filt2] = radec_f200w\n",
    "match_phot_single['mag_' + filt2] = mag_f200w\n",
    "match_phot_single['emag_' + filt2] = emag_f200w\n",
    "\n",
    "print('Number of sources in common between the two filters:', len(match_phot_single)) "
   ]
  },
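  {
   "cell_type": "markdown",
   "id": "9d4b3f12",
   "metadata": {},
   "source": [
    "The magnitude columns computed above are instrumental magnitudes, m = -2.5 log10(F), with no zeropoint applied. This is why they are negative for fluxes above 1. The uncertainty comes from first-order error propagation: since dm/dF = -2.5 / (F ln 10), the error is σ_m ≈ (2.5 / ln 10) σ_F / F ≈ 1.086 σ_F / F. The quick check below uses made-up flux values (the `_demo` names are hypothetical) to show where the 1.086 factor comes from."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d81f5b27",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# hypothetical fitted fluxes and flux uncertainties (arbitrary units)\n",
    "flux_demo = np.array([1.0e4, 2.5e3, 1.0e2])\n",
    "flux_unc_demo = np.array([1.0e2, 5.0e1, 1.0e1])\n",
    "\n",
    "# instrumental magnitudes: no zeropoint applied\n",
    "mag_inst_demo = -2.5 * np.log10(flux_demo)\n",
    "\n",
    "# first-order error propagation: sigma_m = (2.5 / ln 10) * sigma_F / F\n",
    "k_demo = 2.5 / np.log(10)\n",
    "emag_demo = k_demo * flux_unc_demo / flux_demo\n",
    "\n",
    "print('2.5 / ln(10) = {:.4f}'.format(k_demo))\n",
    "print('instrumental magnitudes:', np.round(mag_inst_demo, 3))\n",
    "print('magnitude errors:', np.round(emag_demo, 4))"
   ]
  },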
  {
   "cell_type": "markdown",
   "id": "653892e3",
   "metadata": {},
   "source": [
    "### 8.3<font color='white'>-</font>Load input catalogs<a class=\"anchor\" id=\"load_input\"></a> ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf029b02",
   "metadata": {},
   "source": [
    "If you are analyzing the LW images, the input catalogs are:\n",
    "\n",
    "- F277W: `jw00042001001_01101_00001_nrcb5_uncal_pointsources.list` (link: https://stsci.box.com/shared/static/f8e625es3nwj3p759fquptaqm7lofh1j.list)\n",
    "- F444W: `jw00042001001_01101_00005_nrcb5_uncal_pointsources.list` (link: https://stsci.box.com/shared/static/xgd7ofhr1zti6at0ss1e7y67g9yncxqr.list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34d48853",
   "metadata": {},
   "outputs": [],
   "source": [
    "if os.path.isfile('./jw00042001001_01101_00001_nrcb1_uncal_pointsources.list'):\n",
    "    input_catalog_f115w = './jw00042001001_01101_00001_nrcb1_uncal_pointsources.list'\n",
    "\n",
    "else:\n",
    "    print('Downloading F115W input catalog photometry')\n",
    "    \n",
    "    boxlink_inpcat_f115w = 'https://stsci.box.com/shared/static/1qdcq418fcmeekkzlw8uqmbz1656t3kd.list'\n",
    "    boxfile_inpcat_f115w = './jw00042001001_01101_00001_nrcb1_uncal_pointsources.list'\n",
    "    urllib.request.urlretrieve(boxlink_inpcat_f115w, boxfile_inpcat_f115w)\n",
    "    input_catalog_f115w = boxfile_inpcat_f115w\n",
    "\n",
    "if os.path.isfile('./jw00042001001_01101_00005_nrcb1_uncal_pointsources.list'):\n",
    "    input_catalog_f200w = './jw00042001001_01101_00005_nrcb1_uncal_pointsources.list'\n",
    "\n",
    "else:\n",
    "    print('Downloading F200W input catalog photometry')\n",
    "    \n",
    "    boxlink_inpcat_f200w = 'https://stsci.box.com/shared/static/saeqibfbeccpv5gt9ndlb826o68ikt4b.list'\n",
    "    boxfile_inpcat_f200w = './jw00042001001_01101_00005_nrcb1_uncal_pointsources.list'\n",
    "    urllib.request.urlretrieve(boxlink_inpcat_f200w, boxfile_inpcat_f200w)\n",
    "    input_catalog_f200w = boxfile_inpcat_f200w\n",
    "\n",
    "input_cat_f115w = pd.read_csv(input_catalog_f115w, header=None, sep='\\s+', \n",
    "                              names=['index', 'ra_in', 'dec_in', 'f115w_in'], comment='#', skiprows=2, \n",
    "                              usecols=(0, 3, 4, 7))\n",
    "\n",
    "\n",
    "input_cat_f200w = pd.read_csv(input_catalog_f200w, header=None, sep='\\s+', \n",
    "                              names=['index', 'ra_in', 'dec_in', 'f200w_in'], comment='#', skiprows=2, \n",
    "                              usecols=(0, 3, 4, 7))\n",
    "\n",
    "\n",
    "input_cat_f200w.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38f43944",
   "metadata": {},
   "source": [
    "### 8.4<font color='white'>-</font>Cross-match input catalogs<a class=\"anchor\" id=\"cross_match_input\"></a> ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abd65e2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "radec_input_f115w = SkyCoord(input_cat_f115w['ra_in'], input_cat_f115w['dec_in'], unit='deg')\n",
    "radec_input_f200w = SkyCoord(input_cat_f200w['ra_in'], input_cat_f200w['dec_in'], unit='deg')\n",
    "\n",
    "idx_inp, d2d_inp, _ = match_coordinates_sky(radec_input_f115w, radec_input_f200w)\n",
    "\n",
    "max_sep = 0.015 * u.arcsec\n",
    "\n",
    "sep_constraint_inp = d2d_inp < max_sep\n",
    "\n",
    "f115w_inp = input_cat_f115w['f115w_in'][sep_constraint_inp]\n",
    "f200w_inp = input_cat_f200w['f200w_in'][idx_inp[sep_constraint_inp]]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63e1dce7",
   "metadata": {},
   "source": [
    "### 8.5<font color='white'>-</font>Instrumental Color-Magnitude Diagram<a class=\"anchor\" id=\"cmd\"></a> ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61b42d56",
   "metadata": {},
   "source": [
    "We compare the input (calibrated) color-magnitude diagram (CMD) with the instrumental CMD obtained from the PSF photometry analysis. To obtain a final calibrated color-magnitude diagram, we need to derive the photometric zeropoints for the output catalogs (not covered in this JWebbinar)."
   ]
  },
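  {
   "cell_type": "markdown",
   "id": "e5a8c6b4",
   "metadata": {},
   "source": [
    "Here is a rough illustration of what deriving a zeropoint involves. It is a simplified sketch, not the procedure used for the NIRCam calibration. The idea is to estimate a per-filter offset ZP such that m_cal ≈ m_inst + ZP, taking ZP as the sigma-clipped median of m_cal - m_inst over stars matched between a calibrated catalog and the instrumental one. The data below are synthetic: an arbitrary ZP of 26.0 mag, Gaussian scatter, and a few outliers that stand in for mismatched or blended stars. With the real catalogs, the output photometry would first have to be cross-matched with the input catalog, in the same way as in Section 8.2. The `_zp` variable names are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3e07a9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from astropy.stats import sigma_clipped_stats\n",
    "\n",
    "rng = np.random.default_rng(42)\n",
    "\n",
    "# synthetic matched stars: arbitrary zeropoint, 0.02 mag scatter, 5 outliers\n",
    "true_zp = 26.0\n",
    "mag_inst_zp = rng.uniform(-11.0, -2.0, size=200)\n",
    "mag_cal_zp = mag_inst_zp + true_zp + rng.normal(0.0, 0.02, size=200)\n",
    "mag_cal_zp[:5] += 1.0\n",
    "\n",
    "# sigma clipping removes the outliers before taking the median offset\n",
    "_, zp_median, zp_std = sigma_clipped_stats(mag_cal_zp - mag_inst_zp, sigma=3.0)\n",
    "print('estimated zeropoint: {:.3f} mag (scatter {:.3f} mag)'.format(zp_median, zp_std))\n",
    "\n",
    "# apply the zeropoint to obtain calibrated magnitudes\n",
    "mag_calibrated_zp = mag_inst_zp + zp_median"
   ]
  },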
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ba96df8",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 16))\n",
    "plt.clf()\n",
    "\n",
    "ax1 = plt.subplot(1, 2, 1)\n",
    "\n",
    "xlim0 = -0.25 \n",
    "xlim1 = 1.75 \n",
    "ylim0 = 24\n",
    "ylim1 = 15 \n",
    "\n",
    "ax1.set_xlim(xlim0, xlim1)\n",
    "ax1.set_ylim(ylim0, ylim1)\n",
    "\n",
    "ax1.xaxis.set_major_locator(ticker.AutoLocator())\n",
    "ax1.xaxis.set_minor_locator(ticker.AutoMinorLocator())\n",
    "ax1.yaxis.set_major_locator(ticker.AutoLocator())\n",
    "ax1.yaxis.set_minor_locator(ticker.AutoMinorLocator())\n",
    "\n",
    "ax1.scatter(f115w_inp - f200w_inp, f115w_inp, s=1, color='k')\n",
    "\n",
    "ax1.set_xlabel(filt1+' - '+filt2, fontdict=font2)\n",
    "ax1.set_ylabel(filt1, fontdict=font2)\n",
    "ax1.text(xlim0 + 0.15, ylim1 + 0.25, 'Input', fontdict=font2)\n",
    "ax1.text(xlim0 + 0.15, ylim1 + 0.50, 'N sources = ' + str(len(f115w_inp)), fontdict=font2)\n",
    "\n",
    "ax2 = plt.subplot(1, 2, 2)\n",
    "\n",
    "xlim0 = -0.8\n",
    "xlim1 = 1.2\n",
    "ylim0 = -2\n",
    "ylim1 = -11\n",
    "\n",
    "ax2.set_xlim(xlim0, xlim1)\n",
    "ax2.set_ylim(ylim0, ylim1)\n",
    "\n",
    "ax2.xaxis.set_major_locator(ticker.AutoLocator())\n",
    "ax2.xaxis.set_minor_locator(ticker.AutoMinorLocator())\n",
    "ax2.yaxis.set_major_locator(ticker.AutoLocator())\n",
    "ax2.yaxis.set_minor_locator(ticker.AutoMinorLocator())\n",
    "\n",
    "f115w_single = match_phot_single['mag_' + filt1]\n",
    "f200w_single = match_phot_single['mag_' + filt2]\n",
    "\n",
    "ax2.scatter(f115w_single - f200w_single, f115w_single, s=1, color='k')\n",
    "\n",
    "ax2.set_xlabel(filt1 + '_inst - ' + filt2 + '_inst', fontdict=font2)\n",
    "ax2.set_ylabel(filt1 + '_inst', fontdict=font2)\n",
    "ax2.text(xlim0 + 0.15, ylim1 + 0.25, 'Output', fontdict=font2)\n",
    "ax2.text(xlim0 + 0.15, ylim1 + 0.50, 'N sources = ' + str(len(f115w_single)), fontdict=font2)\n",
    "\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8783d290",
   "metadata": {},
   "source": [
    "9.<font color='white'>-</font>Bonus part II: create a grid of empirical PSFs <a class=\"anchor\" id=\"bonusII\"></a>\n",
    "------------------"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "151442ac",
   "metadata": {},
   "source": [
    "### 9.1<font color='white'>-</font>Count stars in N x N grid<a class=\"anchor\" id=\"count_stars\"></a> ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52ee13e7",
   "metadata": {},
   "source": [
    "The function `count_stars_grid` counts how many good PSF stars fall in each cell of an N x N grid. It starts from a grid of size N x N (where N = sqrt(**num_psfs**)) and iterates down to the minimum grid size of 2 x 2. Depending on how many PSF stars they want in each cell, users can choose the appropriate grid size. Alternatively, they can modify the threshold values and/or the selection parameters adopted during the star detection (Sections 5.3 and 5.4).\n",
    "\n",
    "The minimum number of PSF stars needed in each cell can be set with the parameter **min_numpsf_stars**. This is useful when inspecting the plots: in cells with fewer than **min_numpsf_stars** PSF stars, the count is shown in red. Moreover, with `verbose=True`, the function prints which cells (if any) do not have enough PSF stars for each N x N grid.\n",
    "\n",
    "When `savefig=True`, the function produces sqrt(**num_psfs**) - 1 figures (one for each grid size from N x N down to 2 x 2), showing the number of PSF stars in each cell."
   ]
  },
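  {
   "cell_type": "markdown",
   "id": "0f3d7e21",
   "metadata": {},
   "source": [
    "The counting logic can be reproduced in a few lines of NumPy on synthetic data: 1000 random star positions on a 2048 x 2048 px detector. All names ending in `_demo` are hypothetical. The cell centers follow `find_centers`. As in `count_stars_grid`, a star is counted only if it lies at least `half_size` pixels inside the cell boundaries, so that its full `size` x `size` cutout stays within the cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b92d4e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# hypothetical inputs: detector size and random star positions\n",
    "npix_demo = 2048\n",
    "x_demo = rng.uniform(0, npix_demo, 1000)\n",
    "y_demo = rng.uniform(0, npix_demo, 1000)\n",
    "\n",
    "num_demo = 3            # N x N grid\n",
    "size_demo = 11          # cutout size of the PSF stars\n",
    "half_size_demo = (size_demo - 1) / 2\n",
    "min_stars_demo = 40\n",
    "\n",
    "# cell centers built as in find_centers()\n",
    "points_demo = int(((npix_demo / num_demo) / 2) - 1)\n",
    "centers_1d_demo = np.arange(points_demo, 2 * points_demo * num_demo, 2 * points_demo)\n",
    "\n",
    "counts_demo = np.zeros((num_demo, num_demo), dtype=int)  # indexed as [iy, ix]\n",
    "for ix, xc in enumerate(centers_1d_demo):\n",
    "    for iy, yc in enumerate(centers_1d_demo):\n",
    "        # keep only stars whose full cutout fits inside the cell\n",
    "        inside = ((x_demo > xc - points_demo + half_size_demo) &\n",
    "                  (x_demo < xc + points_demo - half_size_demo) &\n",
    "                  (y_demo > yc - points_demo + half_size_demo) &\n",
    "                  (y_demo < yc + points_demo - half_size_demo))\n",
    "        counts_demo[iy, ix] = np.count_nonzero(inside)\n",
    "\n",
    "print(counts_demo)\n",
    "print('cells below the minimum:', np.argwhere(counts_demo < min_stars_demo).tolist())"
   ]
  },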
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e12d81ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_centers(num):\n",
    "    points = int(((data.shape[0] / num) / 2) - 1)\n",
    "    x_center = np.arange(points, 2 * points * num, 2 * points)\n",
    "    y_center = np.arange(points, 2 * points * num, 2 * points)\n",
    "\n",
    "    centers = np.array(np.meshgrid(x_center, y_center)).T.reshape(-1, 2)\n",
    "\n",
    "    return points, centers\n",
    "\n",
    "\n",
    "def count_stars_grid(num_psfs=4, min_numpsf_stars=40, size=11, verbose=True, savefig=True):\n",
    "\n",
    "    # calculate the number of stars from find_stars in each cell of the grid. The maximum number of cell\n",
    "    # is defined by num_psfs and the function iterate from N x N (where N = sqrt(num_psfs)) until a 2 x 2 grid.\n",
    "\n",
    "    if np.sqrt(num_psfs).is_integer():\n",
    "        grid_points = int(np.sqrt(num_psfs))\n",
    "\n",
    "    else:\n",
    "        raise ValueError(\"You must choose a square number of cells to create (E.g. 9, 16, etc.)\")\n",
    "\n",
    "    num_grid = np.arange(2, grid_points + 1, 1)\n",
    "    num_grid = num_grid[::-1]\n",
    "\n",
    "    for num in num_grid:\n",
    "        print(\"--------------------\")\n",
    "        print(\"\")\n",
    "        print(\"Calculating the number of PSF stars in a %d x %d grid:\" % (num, num))\n",
    "        print(\"\")\n",
    "\n",
    "        # array with the same (ny, nx) shape as the image, filled with the star count of each cell\n",
    "        temp_arr = np.zeros(data.shape)\n",
    "        num_psfs_stars = []\n",
    "\n",
    "        points, centers = find_centers(num)\n",
    "\n",
    "        for n, val in enumerate(centers):\n",
    "\n",
    "            x = found_stars_sel['xcentroid']\n",
    "            y = found_stars_sel['ycentroid']\n",
    "\n",
    "            half_size = (size - 1) / 2\n",
    "\n",
    "            lim1 = int(val[0] - points + half_size)\n",
    "            lim2 = int(val[0] + points - half_size)\n",
    "            lim3 = int(val[1] - points + half_size)\n",
    "            lim4 = int(val[1] + points - half_size)\n",
    "\n",
    "            number_psf_stars = (x > lim1) & (x < lim2) & (y > lim3) & (y < lim4)\n",
    "            count_psfs_stars = np.count_nonzero(number_psf_stars)\n",
    "\n",
    "            lim_x1 = int(lim1 - half_size)\n",
    "            lim_x2 = int(lim2 + half_size)\n",
    "            lim_y1 = int(lim3 - half_size)\n",
    "            lim_y2 = int(lim4 + half_size)\n",
    "\n",
    "            if verbose:\n",
    "\n",
    "                if count_psfs_stars < min_numpsf_stars:\n",
    "                    print('Center coordinates of grid cell {:d} are ({:d}, {:d}) --- Not enough stars in the cell '\n",
    "                          '(number of stars < {:d})'.format(n + 1, val[0], val[1], min_numpsf_stars))\n",
    "\n",
    "                else:\n",
    "                    print(f'Center coordinates of grid cell {n + 1:d} are ({val[0]:d}, {val[1]:d}) '\n",
    "                          '--- Number of stars:', count_psfs_stars)\n",
    "                    print(\"\")\n",
    "\n",
    "            temp_arr[lim_y1:lim_y2, lim_x1:lim_x2] = count_psfs_stars\n",
    "            num_psfs_stars.append(count_psfs_stars)\n",
    "\n",
    "        if savefig:\n",
    "            plot_count_grid(temp_arr, num, num_psfs_stars, centers, min_numpsf_stars=min_numpsf_stars)\n",
    "\n",
    "\n",
    "def plot_count_grid(arr, num, nstars, centers, min_numpsf_stars=40):\n",
    "\n",
    "    from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "\n",
    "    # open a new figure for each grid size\n",
    "    plt.figure(figsize=(10, 10))\n",
    "    ax = plt.subplot(1, 1, 1)\n",
    "\n",
    "    plt.xlabel('X [px]', fontsize=20)\n",
    "    plt.ylabel('Y [px]', fontsize=20)\n",
    "    plt.title('%dx%d grid - ' % (num, num) + det + ' - ' + filt, fontsize=25)\n",
    "    im = ax.imshow(arr, origin='lower', vmin=np.min(arr[arr > 0]), vmax=np.max(arr))\n",
    "    for i in range(num ** 2):\n",
    "        # cells with fewer PSF stars than the requested minimum are labeled in red\n",
    "        if nstars[i] < min_numpsf_stars:\n",
    "            ax.text(centers[i][0] - 100, centers[i][1] - 50, \"%d\" % nstars[i], c='r', fontsize=30)\n",
    "        else:\n",
    "            ax.text(centers[i][0] - 100, centers[i][1] - 50, \"%d\" % nstars[i], c='w', fontsize=30)\n",
    "    ax.text(2300, 750, \"# of PSF stars\", rotation=270, fontsize=25)\n",
    "    divider = make_axes_locatable(ax)\n",
    "    cax = divider.append_axes(\"right\", size=\"5%\", pad=0.05)\n",
    "    plt.colorbar(im, cax=cax)\n",
    "\n",
    "    plt.tight_layout()\n",
    "\n",
    "    filename = 'number_PSFstars_%dx%dgrid_%s_%s.pdf' % (num, num, det, filt)\n",
    "\n",
    "    plt.savefig(os.path.join(figures_dir, filename))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2a4202d",
   "metadata": {},
   "source": [
    "For this example, we use the F115W image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4b6878c",
   "metadata": {},
   "outputs": [],
   "source": [
    "det = 'NRCB1'\n",
    "filt = 'F115W'\n",
    "\n",
    "im = fits.open(dict_images[det][filt]['images'][0])\n",
    "data = im[1].data\n",
    "found_stars_sel = found_stars_sel_f115w\n",
    "\n",
    "figures_dir = 'FIGURES/'\n",
    "\n",
    "if not os.path.exists(figures_dir):\n",
    "    os.makedirs(figures_dir)\n",
    "\n",
    "count_stars_grid(num_psfs=25, min_numpsf_stars=40, size=11, verbose=True, savefig=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "499bea37",
   "metadata": {},
   "source": [
    "### 9.2<font color='white'>-</font>Build effective PSF (single or grid)<a class=\"anchor\" id=\"epsf_grid\"></a> ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e446c23",
   "metadata": {},
   "source": [
    "This function creates a grid of PSFs with `EPSFBuilder` (or a single PSF, when **num_psfs**=1). It returns a `GriddedPSFModel` object containing a 3D array of shape N × n × n, i.e., N 2D ePSFs of n × n pixels each. The model metadata include a `grid_xypos` key with the (x, y) detector position of each PSF. The order of the tuples in `grid_xypos` matches the order of the PSFs along the first axis of the 3D array."
   ]
  },
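  {
   "cell_type": "markdown",
   "id": "b2c9a1d8",
   "metadata": {},
   "source": [
    "The ordering convention is easy to get wrong. The NumPy-only sketch below (no photutils needed) reproduces how `find_centers` orders the cells, and how that order carries over to `grid_xypos` and to the first axis of the [i, y, x] ePSF cube written by `writeto`. The 2 x 2 grid, the 2048 px detector size and the variable names are illustrative. The x coordinate varies slowest, so PSF #1 lies above PSF #0 rather than to its right. This is consistent with `plot_epsf`, which places PSF `i = ix * nn + iy` in row `nn - 1 - iy` and column `ix`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7f1c8e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# illustrative 2 x 2 grid on a 2048 x 2048 px detector, built as in find_centers()\n",
    "npix_grid, num_grid_demo = 2048, 2\n",
    "points_grid = int(((npix_grid / num_grid_demo) / 2) - 1)\n",
    "axis_centers = np.arange(points_grid, 2 * points_grid * num_grid_demo, 2 * points_grid)\n",
    "centers_grid = np.array(np.meshgrid(axis_centers, axis_centers)).T.reshape(-1, 2)\n",
    "\n",
    "# ePSF cube in the [i, y, x] layout: n = size * oversample pixels per side\n",
    "size_grid, oversample_grid = 11, 4\n",
    "n_grid = size_grid * oversample_grid\n",
    "epsf_cube = np.zeros((num_grid_demo ** 2, n_grid, n_grid))\n",
    "\n",
    "# grid_xypos holds (x, y) tuples in the same order as the first axis of the cube\n",
    "grid_xypos_demo = [(float(cx), float(cy)) for cx, cy in centers_grid]\n",
    "for i, (cx, cy) in enumerate(grid_xypos_demo):\n",
    "    print('PSF #{}: (x, y) = ({:.0f}, {:.0f}), stamp shape {}'.format(i, cx, cy, epsf_cube[i].shape))"
   ]
  },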
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "113579fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_epsf_grid(num_psfs=4, size=11, oversample=4, save_epsf=True, savefig=True, overwrite=True):\n",
    "\n",
    "    if np.sqrt(num_psfs).is_integer():\n",
    "        num_grid = int(np.sqrt(num_psfs))\n",
    "\n",
    "    else:\n",
    "        raise ValueError(\"You must choose a square number of cells to create (E.g. 9, 16, etc.)\")\n",
    "\n",
    "    points, centers = find_centers(num_grid)\n",
    "\n",
    "    epsf_size = size * oversample\n",
    "    epsf_arr = np.empty((num_grid ** 2, epsf_size, epsf_size))\n",
    "\n",
    "    for i, val in enumerate(centers):\n",
    "\n",
    "        x = found_stars_sel['xcentroid']\n",
    "        y = found_stars_sel['ycentroid']\n",
    "\n",
    "        half_size = (size - 1) / 2\n",
    "\n",
    "        lim1 = int(val[0] - points + half_size)\n",
    "        lim2 = int(val[0] + points - half_size)\n",
    "        lim3 = int(val[1] - points + half_size)\n",
    "        lim4 = int(val[1] + points - half_size)\n",
    "\n",
    "        mask = ((x > lim1) & (x < lim2) & (y > lim3) & (y < lim4))\n",
    "\n",
    "        stars_tbl = Table()\n",
    "        stars_tbl['x'] = x[mask]\n",
    "        stars_tbl['y'] = y[mask]\n",
    "        print('Number of sources in cell %d used to build the ePSF:' % (i + 1), len(stars_tbl['x']))\n",
    "\n",
    "        nddata = NDData(data=data_bkgsub)\n",
    "        stars = extract_stars(nddata, stars_tbl, size=size)\n",
    "\n",
    "        print(\"Creating ePSF for cell %d - Coordinates (%d, %d)\" % (i + 1, val[0], val[1]))\n",
    "        print(\"\")\n",
    "\n",
    "        epsf_builder = EPSFBuilder(oversampling=oversample, maxiters=3, progress_bar=False)\n",
    "\n",
    "        epsf, fitted_stars = epsf_builder(stars)\n",
    "\n",
    "        # EPSFBuilder returns an odd-sized ePSF; drop one row and one column so it\n",
    "        # fits the (epsf_size, epsf_size) slot of epsf_arr\n",
    "        epsf_arr[i, :, :] = epsf.data[:epsf.shape[0]-1, 1:epsf.shape[0]]\n",
    "\n",
    "    # header metadata: built once, after all the ePSFs have been created\n",
    "    meta = OrderedDict()\n",
    "    meta[\"DETECTOR\"] = (det, \"Detector name\")\n",
    "    meta[\"FILTER\"] = (filt, \"Filter name\")\n",
    "    meta[\"NUM_PSFS\"] = (num_grid ** 2, \"The total number of ePSFs\")\n",
    "    for h, loc in enumerate(centers):\n",
    "        # plain Python floats keep the '(y, x)' string format parsed by create_model()\n",
    "        meta[\"DET_YX{}\".format(h)] = (str((float(loc[1]), float(loc[0]))), \n",
    "                                      \"The #{} PSF's (y,x) detector pixel position\".format(h))\n",
    "\n",
    "    meta[\"OVERSAMP\"] = (oversample, \"Oversampling Factor in EPSFBuilder\")\n",
    "\n",
    "    model_epsf = create_model(epsf_arr, meta)\n",
    "\n",
    "    if savefig:\n",
    "        plot_epsf(model_epsf, num_psfs)\n",
    "\n",
    "    if save_epsf:\n",
    "        writeto(epsf_arr, meta, num_psfs, overwrite=overwrite)\n",
    "\n",
    "    # return the model whether or not it is saved to disk\n",
    "    return model_epsf\n",
    "\n",
    "\n",
    "def writeto(data, meta, num_psfs, overwrite=True):\n",
    "\n",
    "    primaryhdu = fits.PrimaryHDU(data)\n",
    "\n",
    "    # Convert meta dictionary to header\n",
    "    tuples = [(a, b, c) for (a, (b, c)) in meta.items()]\n",
    "    primaryhdu.header.extend(tuples)\n",
    "\n",
    "    # Add extra descriptors for how the file was made\n",
    "    primaryhdu.header[\"COMMENT\"] = \"For a given filter and detector, one file is produced in \"\n",
    "    primaryhdu.header[\"COMMENT\"] = \"the form [i, y, x] where i is the ePSF position on the detector grid \"\n",
    "    primaryhdu.header[\"COMMENT\"] = \"and (y,x) is the 2D PSF. The order of PSFs can be found under the \"\n",
    "    primaryhdu.header[\"COMMENT\"] = \"header DET_YX* keywords\"\n",
    "\n",
    "    hdu = fits.HDUList(primaryhdu)\n",
    "\n",
    "    filename = \"epsf_{}_{}_nepsf{}.fits\".format(det, filt, num_psfs)\n",
    "    \n",
    "    file = os.path.join(psfs_dir, filename)\n",
    "\n",
    "    hdu.writeto(file, overwrite=overwrite)\n",
    "\n",
    "\n",
    "def plot_epsf(model, num):\n",
    "\n",
    "    if num == 1:\n",
    "        plt.clf()\n",
    "        plt.figure(figsize=(10, 10))\n",
    "        ax = plt.subplot(1, 1, 1)\n",
    "\n",
    "        norm_epsf = simple_norm(model.data[0], 'log', percent=99.)\n",
    "        plt.suptitle(det + ' - ' + filt, fontsize=20)\n",
    "        plt.title(model.meta['grid_xypos'][0], fontsize=20)\n",
    "        ax.imshow(model.data[0], norm=norm_epsf)\n",
    "        plt.tight_layout()\n",
    "\n",
    "        filename = 'ePSF_single_%s_%s.pdf' % (det, filt)\n",
    "        \n",
    "        plt.savefig(os.path.join(figures_dir, filename))\n",
    " \n",
    "    else:\n",
    "        plt.clf()\n",
    "\n",
    "        nn = int(np.sqrt(num))\n",
    "        figsize = (12, 12)\n",
    "        fig, ax = plt.subplots(nn, nn, figsize=figsize)\n",
    "\n",
    "        for ix in range(nn):\n",
    "            for iy in range(nn):\n",
    "                i = ix * nn + iy\n",
    "                norm_epsf = simple_norm(model.data[i], 'log', percent=99.)\n",
    "                ax[nn - 1 - iy, ix].imshow(model.data[i], norm=norm_epsf)\n",
    "                ax[nn - 1 - iy, ix].set_title(model.meta['grid_xypos'][i], fontsize=20)\n",
    "\n",
    "        plt.suptitle(det + ' - ' + filt, fontsize=40)\n",
    "        plt.tight_layout()\n",
    "    \n",
    "        filename = 'ePSF_%dx%dgrid_%s_%s.pdf' % (nn, nn, det, filt)\n",
    "        \n",
    "        plt.savefig(os.path.join(figures_dir, filename))\n",
    "\n",
    "\n",
    "def create_model(data, meta):\n",
    "\n",
    "    ndd = NDData(data, meta=meta, copy=True)\n",
    "\n",
    "    ndd.meta['grid_xypos'] = [((float(ndd.meta[key][0].split(',')[1].split(')')[0])), \n",
    "                               (float(ndd.meta[key][0].split(',')[0].split('(')[1]))) for key in ndd.meta.keys()\n",
    "                              if \"DET_YX\" in key]\n",
    "\n",
    "    ndd.meta['oversampling'] = meta[\"OVERSAMP\"][0]\n",
    "    ndd.meta = {key.lower(): ndd.meta[key] for key in ndd.meta}\n",
    "    model = GriddedPSFModel(ndd)\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93a5c27e",
   "metadata": {},
   "outputs": [],
   "source": [
    "psfs_dir = 'PSF_MODELS/'\n",
    "\n",
    "if not os.path.exists(psfs_dir):\n",
    "    os.makedirs(psfs_dir)\n",
    "    \n",
    "data_bkgsub, _ = calc_bkg()\n",
    "\n",
    "epsf_grid = build_epsf_grid(num_psfs=4, size=11, oversample=4, save_epsf=True, savefig=True, overwrite=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5630029f-31d1-42cd-8454-225e86cabc48",
   "metadata": {},
   "source": [
    "<hr style=\"border:1px solid gray\"> </hr>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "843b5201-6f57-46f0-9da0-b738714178d3",
   "metadata": {},
   "source": [
    "<img style=\"float: center;\" src=\"https://raw.githubusercontent.com/spacetelescope/notebooks/master/assets/stsci_pri_combo_mark_horizonal_white_bkgd.png\" alt=\"Space Telescope Logo\" width=\"200px\"/>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "toc-showcode": false
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
