import os, sys
import numpy as np
import pandas as pd
import glob
import pickle
import time
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.cm as cm
import math

import lsst.daf.persistence as dafPersist
import lsst.afw.display as afwDisplay
afwDisplay.setDefaultBackend("matplotlib")

from pfs.drp.stella import SpectrumSet
from pfs.drp.stella.subtractSky1d import subtractSky1d
from pfs.drp.stella.utils import showDetectorMap
from pfs.drp.stella.utils import addPfsCursor

from pfs.utils.fiberids import FiberIds
from pfs.datamodel import PfsDesign, FiberStatus, TargetType
from scipy.stats import iqr
from scipy.optimize import curve_fit

# Half-width (in pixels along x) of the window cut around the rounded trace
# centre; windows span 2*xwin+1 columns. Used by getStatsPerCell,
# getStatsPerFiber and drawFiberResidual.
xwin=3
# Half-height (in pixels along y) of the cell used by getStatsPerCell;
# cells span 2*ywin+1 rows.
ywin=3
# Half-width (in pixels along x) of the narrow cross-section passed to the
# Gaussian fits in extractionQA.
xfit = 5
# Half-width (in pixels along x) of the wider cross-section drawn in the
# per-fiber profile panels of extractionQA.
xwidth = 15

def gaussian_func(x, a,  mu, sigma):
    """Evaluate a Gaussian of amplitude ``a`` centred on ``mu`` with width ``sigma``.

    Parameters
    ----------
    x : scalar or array-like
        Position(s) at which to evaluate the profile.
    a : float
        Peak amplitude (value at ``x == mu``).
    mu : float
        Centre of the profile.
    sigma : float
        Standard deviation of the profile.

    Returns
    -------
    Same shape as ``x``: ``a * exp(-(x - mu)**2 / (2 * sigma**2))``.
    """
    offset = x - mu
    exponent = -(offset ** 2) / (2 * sigma ** 2)
    return a * np.exp(exponent)

def getStatsPerCell(data, pfsConfig, detectorMap, fiberId):
    """Compute pixel statistics in small cells sampled along one fiber trace.

    The trace is sampled at 148 rows (150 evenly spaced positions between row
    0 and the image height, with both end points dropped). Around each sample
    a cell of ``(2*ywin+1) x (2*xwin+1)`` pixels is cut and the mean, median
    and standard deviation of its unmasked (``mask == 0``) pixels are taken.

    Parameters
    ----------
    data :
        Object exposing ``getDimensions()``, ``image.array`` and
        ``mask.array`` (the callers in this file pass a masked image).
    pfsConfig :
        Unused; kept so existing call sites keep working.
    detectorMap :
        Object exposing ``getXCenter(fiberId, y)``.
    fiberId : int
        Fiber whose trace is sampled.

    Returns
    -------
    xo, yo : numpy.ndarray
        Floating-point trace positions of the samples.
    xs, ys : numpy.ndarray
        The same positions rounded to integer pixel indices.
    image_mean, image_median, image_stddev : numpy.ndarray
        Per-cell statistics; NaN where a cell has no unmasked pixel.
    """
    ymax = data.getDimensions()[1]
    yo = np.array(np.linspace(0, ymax, 150)[1:-1])
    xo = detectorMap.getXCenter(fiberId, yo)
    xs = np.round(xo).astype(int)
    ys = np.round(yo).astype(int)
    image = data.image.array
    mask = data.mask.array

    image_mean = []
    image_median = []
    image_stddev = []
    for x, y in zip(xs, ys):
        # Clamp the lower bounds at 0: a negative slice start would wrap
        # around to the far side of the array and select an empty or wrong
        # region. Upper bounds are truncated by numpy itself.
        rows = slice(max(y - ywin, 0), max(y + ywin + 1, 0))
        cols = slice(max(x - xwin, 0), max(x + xwin + 1, 0))
        good = image[rows, cols][mask[rows, cols] == 0]
        if good.size == 0:
            # Nothing usable in this cell; record NaN without triggering
            # numpy's "empty slice" warnings.
            image_mean.append(np.nan)
            image_median.append(np.nan)
            image_stddev.append(np.nan)
            continue
        image_mean.append(np.nanmean(good))
        image_median.append(np.nanmedian(good))
        image_stddev.append(np.nanstd(good))

    return (xo, yo, xs, ys,
            np.array(image_mean), np.array(image_median), np.array(image_stddev))

def getStatsPerFiber(data, detectorMap, fiberId):
    """Cut a ``2*xwin+1`` pixel wide strip along one fiber trace and summarise it.

    For every detector row the pixels within ``xwin`` columns of the rounded
    trace centre are copied into one row of the output strip, so that column
    ``xwin`` of the strip always corresponds to the trace centre.

    Where the window extends beyond the detector the missing pixels are
    filled with 0 in the image strip and flagged with a non-zero value in the
    mask strip, so that they are excluded from every ``mask == 0`` selection
    instead of being counted as valid zero-valued samples.

    Parameters
    ----------
    data :
        Object exposing ``getDimensions()``, ``image.array`` and
        ``mask.array``.
    detectorMap :
        Object exposing ``getXCenter(fiberId, y)``.
    fiberId : int
        Fiber to extract.

    Returns
    -------
    chi_square : float
        Mean of the squared unmasked strip pixels.
    im_ave : float
        Mean of the unmasked strip pixels.
    x_ave : float
        Mean x position of the trace.
    imageFiber : numpy.ndarray
        Image strip of shape ``(height, 2*xwin+1)``.
    maskFiber : numpy.ndarray
        Mask strip of the same shape; non-zero means "do not use".
    """
    ymax = data.getDimensions()[1]
    xmax = data.getDimensions()[0]
    yo = np.arange(0, ymax).astype(np.float64)
    xo = detectorMap.getXCenter(fiberId, yo)
    xs = np.round(xo).astype(int)
    ys = np.round(yo).astype(int)
    image = data.image.array
    mask = data.mask.array

    ncol = 2 * xwin + 1
    imageFiber = np.zeros((len(ys), ncol), dtype=image.dtype)
    # Start fully masked; only pixels actually copied from the detector are
    # given their real mask value below.
    maskFiber = np.ones((len(ys), ncol), dtype=mask.dtype)
    for row, (x, y) in enumerate(zip(xs, ys)):
        left = x - xwin
        # Clamp to the detector: a negative slice start would wrap around to
        # the opposite edge of the row.
        srcLo = max(left, 0)
        srcHi = min(x + xwin + 1, xmax)
        if srcHi <= srcLo:
            continue  # window lies entirely off the detector
        # Place the pixels at their true offset from the trace centre so the
        # padding ends up on the side that is off the detector.
        dstLo = srcLo - left
        dstHi = srcHi - left
        imageFiber[row, dstLo:dstHi] = image[y, srcLo:srcHi]
        maskFiber[row, dstLo:dstHi] = mask[y, srcLo:srcHi]

    good = imageFiber[maskFiber == 0]
    chi_square = np.average(good ** 2)
    im_ave = np.average(good)

    return chi_square, im_ave, np.average(xo), imageFiber, maskFiber


def _clampWindow(center, halfWidth, size):
    """Return ``(lo, hi)`` bounds of ``center +/- halfWidth`` clipped to ``[0, size]``."""
    lo = min(max(center - halfWidth, 0), size)
    hi = min(max(center + halfWidth + 1, lo), size)
    return lo, hi


def _fitProfilePair(xcoord, modelCut, dataCut, p0Model, p0Data):
    """Fit Gaussians to a model and a data cross-section.

    Returns ``(modelCenter, modelWidth, dataCenter, dataWidth)``, or ``None``
    when either fit fails or the relative 1-sigma error of a fitted centre or
    width is 10% or more.
    """
    try:
        poptModel, pcovModel = curve_fit(gaussian_func, xcoord, modelCut, p0=np.array(p0Model))
        poptData, pcovData = curve_fit(gaussian_func, xcoord, dataCut, p0=np.array(p0Data))
    except (RuntimeError, ValueError, TypeError):
        # RuntimeError: no convergence; ValueError: non-finite input;
        # TypeError: fewer data points than parameters.
        return None
    errModel = np.sqrt(np.diag(pcovModel))
    errData = np.sqrt(np.diag(pcovData))
    # NOTE(review): a negative fitted sigma makes its ratio negative and so
    # passes this test; kept as in the original acceptance criterion.
    reliable = (errModel[1] / poptModel[1] < 0.1 and errModel[2] / poptModel[2] < 0.1 and
                errData[1] / poptData[1] < 0.1 and errData[2] / poptData[2] < 0.1)
    if not reliable:
        return None
    return poptModel[1], poptModel[2], poptData[1], poptData[2]


def _scatterByTarget(axis, xvalues, yvalues, targetMask, colorByTarget):
    """Scatter ``yvalues`` against ``xvalues``, one labelled series per non-empty target group."""
    for target, selected in targetMask.items():
        count = np.sum(selected)
        if count > 0:
            axis.scatter(xvalues[selected], yvalues[selected], 10., c=colorByTarget[target],
                         label='{}: {} fibers'.format(target, count))


def _drawLimitArrows(axis, xvalues, flagged, tail, length):
    """Draw vertical arrows at the flagged x positions to mark points beyond the y limits."""
    count = np.sum(flagged)
    if count > 0:
        axis.quiver(xvalues[flagged], np.zeros(count) + tail, 0., length,
                    color="k", angles="xy", scale_units="xy", scale=1, width=0.005)


def _saveStatsPage(pp, fig, dataId):
    """Title a statistics figure with the dataId, append it to ``pp`` and close it."""
    fig.suptitle("visit=%(visit)d arm=%(arm)s spectrograph=%(spectrograph)d" % dataId)
    fig.savefig(pp, format="pdf")
    plt.close(fig)


def _plotSymmetricChiPage(pp, dataId, fiberIds, xa, values, targetMask, colorByTarget,
                          ylabel, yrange, ql, legendOnX):
    """Plot a per-fiber statistic vs fiberId and vs X with y limits ``+/- yrange``.

    Fibers outside the limits are indicated by arrows at the plot edge.
    """
    tooLarge = values > yrange
    tooSmall = values < -yrange
    fig, ax = plt.subplots(1, 2, figsize=(16, 7))
    for axis, xvalues in ((ax[0], fiberIds), (ax[1], xa)):
        axis.plot([np.amin(xvalues), np.amax(xvalues)], [0, 0], 'gray')
        _scatterByTarget(axis, xvalues, values, targetMask, colorByTarget)
        _drawLimitArrows(axis, xvalues, tooLarge, yrange - yrange * ql, yrange * ql)
        _drawLimitArrows(axis, xvalues, tooSmall, -yrange + yrange * ql, -yrange * ql)
        axis.set_ylim(-yrange, yrange)
    ax[0].set_ylabel(ylabel)
    ax[0].set_xlabel("fiberId")
    ax[0].legend()
    ax[1].set_xlabel("X (pix)")
    if legendOnX:
        ax[1].legend()
    _saveStatsPage(pp, fig, dataId)


def _plotFiberOverview(pp, title, calexp, subtracted, chiimage, xo, yo, xCenter, chiRowAve,
                       centerdif, widthdif, ydif, ymin, ymax):
    """Write the six-panel per-fiber page.

    Panels: calexp, residual and chi cut-outs around the trace, the row-wise
    chi average, and the fitted centre and width differences against row.
    """
    fig, ax = plt.subplots(1, 6, figsize=(12, 7))
    plt.subplots_adjust(wspace=0.5)
    xlo, xhi = int(xCenter) - 10, int(xCenter) + 11

    plt.sca(ax[0])
    disp = afwDisplay.Display(fig)
    disp.scale('asinh', 'zscale', Q=1)
    disp.setMaskPlaneColor("REFLINE", afwDisplay.IGNORE)
    disp.mtv(calexp[xlo:xhi, :])
    ax[0].plot(xo, yo, "r", alpha=0.8)
    ax[0].set_xlim(xCenter - 10, xCenter + 10)
    ax[0].set_aspect("auto")
    ax[0].set_ylabel("Y (pix)")
    ax[0].set_title("calexp")

    plt.sca(ax[1])
    disp = afwDisplay.Display(fig)
    disp.setMaskPlaneColor("REFLINE", afwDisplay.BLACK)
    disp.setImageColormap('coolwarm')
    disp.scale('asinh', 'zscale', Q=1)
    disp.mtv(subtracted[xlo:xhi, :])
    ax[1].plot(xo, yo, "k", alpha=0.8)
    ax[1].set_xlim(xCenter - 10, xCenter + 10)
    ax[1].set_aspect("auto")
    ax[1].tick_params(labelbottom=True, labelleft=False, labelright=False, labeltop=False)
    ax[1].set_xlabel("X (pix)")
    ax[1].set_title("Residual")

    plt.sca(ax[2])
    disp = afwDisplay.Display(fig)
    disp.scale('linear', -5, 5, Q=1)
    disp.setMaskPlaneColor("REFLINE", afwDisplay.BLACK)
    disp.setImageColormap('coolwarm')
    disp.mtv(chiimage[xlo:xhi, :])
    ax[2].plot(xo, yo, "k", alpha=0.8)
    ax[2].set_xlim(xCenter - 10, xCenter + 10)
    ax[2].set_aspect("auto")
    ax[2].tick_params(labelbottom=True, labelleft=False, labelright=False, labeltop=False)
    ax[2].set_title("Chi")

    ax[3].scatter(chiRowAve, yo, s=3)
    ax[3].set_ylim(ymin, ymax)
    ax[3].plot([0, 0], [ymin, ymax], "0.8")
    ax[3].set_xlabel("Chi at each row")
    ax[3].tick_params(labelbottom=True, labelleft=False, labelright=False, labeltop=False)
    ax[3].set_title("Chi")

    ax[4].scatter(centerdif, ydif, s=3)
    ax[4].set_ylim(ymin, ymax)
    ax[4].plot([0, 0], [ymin, ymax], "0.8")
    ax[4].set_xlabel("dx")
    ax[4].tick_params(labelbottom=True, labelleft=False, labelright=False, labeltop=False)
    ax[4].set_title("Peak center diff.")

    ax[5].scatter(widthdif, ydif, s=3)
    ax[5].set_ylim(ymin, ymax)
    ax[5].plot([0, 0], [ymin, ymax], "0.8")
    ax[5].set_xlabel(r"d$\sigma$/$\sigma$")
    ax[5].tick_params(labelbottom=True, labelleft=False, labelright=True, labeltop=False)
    ax[5].set_title("Width diff.")

    fig.suptitle(title, fontsize=12)
    fig.savefig(pp, format="pdf")
    plt.close(fig)


def _plotFiberCrossSections(pp, title, image, calexp, subtracted, xo, ymin, ymax, imageWidth):
    """Write the 4x4 page of cross-dispersion profiles at 16 rows along the trace.

    Each panel shows the model (pfsArm) profile, the calexp profile, five
    times their residual and the trace position.
    """
    ys = np.arange(ymin, ymax, (ymax - ymin) / (4 * 4 + 1)).astype("int32")
    xint = xo.astype("int32")
    fig, ax = plt.subplots(4, 4, figsize=(12, 7))
    for j in range(4):
        for k in range(4):
            panel = ax[j][k]
            yssub = ys[k + j * 4 + 1]
            xssub = xint[yssub]
            # Use the same clamped bounds for coordinates and pixel cuts so
            # their lengths always agree, also next to the detector edges.
            wideLo, wideHi = _clampWindow(xssub, xwidth, imageWidth)
            lo, hi = _clampWindow(xssub, xfit, imageWidth)
            if hi <= lo:
                continue  # trace is off the detector at this row; leave the panel empty
            xcoord = np.arange(wideLo, wideHi)
            pfsArmCut = image.array[yssub, wideLo:wideHi]
            calExpCut = calexp.image.array[yssub, wideLo:wideHi]
            xcoordNarrow = np.arange(lo, hi)
            pfsArmCutNarrow = image.array[yssub, lo:hi]
            calExpCutNarrow = calexp.image.array[yssub, lo:hi]

            # Flux-weighted moments provide the starting point of the fits.
            pfsArmCenter0 = np.sum(pfsArmCutNarrow * xcoordNarrow) / np.sum(pfsArmCutNarrow)
            calExpCenter0 = np.sum(calExpCutNarrow * xcoordNarrow) / np.sum(calExpCutNarrow)
            pfsArmWidth0 = (np.sum(pfsArmCutNarrow * (xcoordNarrow - pfsArmCenter0) ** 2)
                            / np.sum(pfsArmCutNarrow)) ** 0.5
            calExpWidth0 = (np.sum(calExpCutNarrow * (xcoordNarrow - calExpCenter0) ** 2)
                            / np.sum(calExpCutNarrow)) ** 0.5
            fit = _fitProfilePair(xcoordNarrow, pfsArmCutNarrow, calExpCutNarrow,
                                  [np.max(pfsArmCutNarrow), pfsArmCenter0, pfsArmWidth0],
                                  [np.max(calExpCutNarrow), calExpCenter0, calExpWidth0])
            if fit is None:
                pfsArmCenter, pfsArmWidth = math.nan, math.nan
                calExpCenter, calExpWidth = math.nan, math.nan
            else:
                pfsArmCenter, pfsArmWidth, calExpCenter, calExpWidth = fit

            panel.plot(xcoord, np.zeros(xcoord.shape), "k--")
            panel.step(xcoord, pfsArmCut,
                       label="pfsArm\n(x={:.2f}, $\\sigma$={:.2f})".format(pfsArmCenter, pfsArmWidth),
                       color="b")
            panel.step(xcoord, calExpCut,
                       label="calExp\n(x={:.2f}, $\\sigma$={:.2f})".format(calExpCenter, calExpWidth),
                       color="k")
            panel.step(xcoord, subtracted.image.array[yssub, wideLo:wideHi] * 5,
                       label="Residual*5", color="r")
            ypeak = np.amax(pfsArmCutNarrow)
            panel.plot([xo[yssub], xo[yssub]], [-ypeak / 10 * 3, ypeak * 1.5], "b--", label="Trace")
            panel.set_ylim(-ypeak / 10 * 3, ypeak * 1.5)
            panel.set_title(r"Y={} (dx={:.1e}, d$\sigma$={:.1e})".format(
                yssub, calExpCenter - pfsArmCenter, calExpWidth - pfsArmWidth), fontsize=8)
            panel.legend(fontsize=4)
            panel.tick_params(labelbottom=(j == 3), labelleft=False, labelright=False, labeltop=False)
    fig.suptitle(title, fontsize=12)
    fig.savefig(pp, format="pdf")
    plt.close(fig)


def extractionQA(dataId, figDir, plotBool=True, plotNum=20):
    """Run the extraction quality assessment for one visit/arm/spectrograph.

    The 2-D model image is rebuilt from the extracted ``pfsArm`` spectra and
    the fiber profiles, subtracted from the ``calexp`` and divided by the
    square root of the calexp variance to give a chi image. Per-fiber
    statistics of that chi image are computed, Gaussian fits compare the
    profile centre and width of the model with the data, and the results are
    written to ``figDir`` as:

    * ``<visit>_<arm><spectrograph>_chifiber.pdf`` (only when ``plotBool``
      and at least one fiber has a chi scatter above 1.2),
    * ``<visit>_<arm><spectrograph>_chifiber.pickle``,
    * ``<visit>_<arm><spectrograph>_chistats.pdf``.

    The data products are read through the module-level ``butler``, which
    must exist before this function is called.

    Parameters
    ----------
    dataId : dict
        Must contain ``visit``, ``arm`` and ``spectrograph``.
    figDir : str
        Existing directory that receives the output files.
    plotBool : bool, optional
        Whether to write the per-fiber diagnostic PDF.
    plotNum : int, optional
        Per-fiber pages are written for fibers whose chi scatter is at least
        ``max(1.2, plotNum-th largest chi scatter)``.

    Returns
    -------
    (float, float)
        NaN-ignoring mean and median of the per-fiber chi scatter.
    """
    try:
        pfsConfig = butler.get('pfsConfig', dataId)
        configBool = True
    except Exception:
        # Best effort: without a pfsConfig all fibers of the pfsArm are used.
        pfsConfig = None
        configBool = False
    calexp = butler.get('calexp', dataId)
    detMap = butler.get('detectorMap_used', dataId)
    profiles = butler.get("fiberProfiles", dataId)
    pfsArm = butler.get("pfsArm", dataId)

    spectra = SpectrumSet.fromPfsArm(pfsArm)
    traces = profiles.makeFiberTracesFromDetectorMap(detMap)
    image = spectra.makeImage(calexp.getDimensions(), traces)

    subtracted = calexp.clone()
    subtracted.image -= image

    chiimage = subtracted.clone()
    chiimage.image.array /= np.sqrt(calexp.variance.array)

    targetMask = {}
    if configBool:
        # Select on the spectrograph of this dataId rather than on a global
        # that only exists when the file is run as a script.
        msk = (pfsConfig.spectrograph == dataId["spectrograph"]) & (pfsConfig.fiberStatus == FiberStatus.GOOD)
        fiberIds = pfsConfig[msk].fiberId
        for t in TargetType:
            targetMask[t] = pfsConfig[msk].targetType == t
    else:
        fiberIds = pfsArm.fiberId
        targetMask["No config"] = np.ones(fiberIds.shape, dtype=bool)

    palette = ['r', 'b', 'g', 'y', 'gray', 'orange', 'cyan', 'k']
    ct = {t: palette[n % len(palette)] for n, t in enumerate(targetMask)}

    data = chiimage.maskedImage
    xa = []
    chiSquare = []
    chiAve = []
    chiMedian = []
    chiStd = []
    chiPeak = []
    chiShift = []
    pfsArmAve = []
    chiAveSpec = []
    for fiberId in fiberIds:
        chi2, _, x_ave, chi_f, mask_f = getStatsPerFiber(data, detMap, fiberId=fiberId)
        chi_f[mask_f != 0] = float('nan')
        good = chi_f[mask_f == 0]

        chiSquare.append(chi2)
        chiAve.append(np.average(good))
        chiMedian.append(np.median(good))
        # Robust scatter: the IQR of a normal distribution is 1.349 sigma.
        chiStd.append(iqr(good) / 1.349)
        # Column xwin of the strip is the trace centre; the columns to its
        # right and left measure a shift between model and data.
        chiPeak.append(np.average(chi_f[:, xwin][mask_f[:, xwin] == 0]))
        rightSide = chi_f[:, xwin + 1:][mask_f[:, xwin + 1:] == 0]
        leftSide = chi_f[:, :xwin][mask_f[:, :xwin] == 0]
        chiShift.append(np.average(rightSide) - np.average(leftSide))
        pfsArmAve.append(np.average(pfsArm[pfsArm.fiberId == fiberId].flux[0]))
        chiAveSpec.append(np.average(chi_f, axis=1))
        xa.append(x_ave)

    pfsArmAve = np.array(pfsArmAve)
    chiSquare = np.array(chiSquare)
    chiAve = np.array(chiAve)
    chiMedian = np.array(chiMedian)
    chiStd = np.array(chiStd)
    chiPeak = np.array(chiPeak)
    chiShift = np.array(chiShift)
    xa = np.array(xa)

    ymin = 0
    ymax = data.getDimensions()[1]
    imageWidth = data.getDimensions()[0]
    yo = np.arange(ymin, ymax).astype(np.float64)
    thres = 1.2
    # Threshold for per-fiber pages: the plotNum-th largest finite scatter,
    # clamped so fewer than plotNum fibers no longer raises IndexError.
    finiteStd = np.sort(chiStd[np.isfinite(chiStd)])
    if finiteStd.size > 0:
        rank = min(max(plotNum, 1), finiteStd.size)
        thresPlot = max(thres, finiteStd[-rank])
    else:
        thresPlot = thres

    xarray = []
    yarray = []
    dxarray = []
    dwarray = []
    if np.any(chiStd > thres):
        pp = None
        if plotBool:
            pp = PdfPages(figDir + '/%(visit)d_%(arm)s%(spectrograph)d_chifiber.pdf' % dataId)
        try:
            for i, fiberId in enumerate(fiberIds):
                if not (chiStd[i] >= thres or pfsArmAve[i] > 1000):
                    continue
                print(fiberId, chiStd[i])
                xo = detMap.getXCenter(fiberId, yo)

                # Compare model and data profiles at 200 rows along the trace.
                ysFit = np.arange(ymin, ymax, (ymax - ymin) / 200).astype("int32")
                xint = xo.astype("int32")
                centerdif = []
                widthdif = []
                ydif = []
                for yssub in ysFit[:200]:
                    xssub = xint[yssub]
                    lo, hi = _clampWindow(xssub, xfit, imageWidth)
                    if hi - lo < 3:
                        continue  # not enough pixels on the detector for a 3-parameter fit
                    xcoordNarrow = np.arange(lo, hi)
                    pfsArmCutNarrow = image.array[yssub, lo:hi]
                    calExpCutNarrow = calexp.image.array[yssub, lo:hi]
                    fit = _fitProfilePair(
                        xcoordNarrow, pfsArmCutNarrow, calExpCutNarrow,
                        [np.max(pfsArmCutNarrow), np.median(xcoordNarrow), 1.],
                        [np.max(calExpCutNarrow), np.median(xcoordNarrow), 1.])
                    if fit is None:
                        continue
                    pfsArmCenter, pfsArmWidth, calExpCenter, calExpWidth = fit
                    centerdif.append(calExpCenter - pfsArmCenter)
                    ydif.append(yssub)
                    widthdif.append((calExpWidth - pfsArmWidth) / calExpWidth)
                    xarray.append(xssub)
                    yarray.append(yssub)

                dxarray += centerdif
                dwarray += widthdif

                if pp is not None and chiStd[i] >= thresPlot:
                    title = "visit={:d} arm={:s} spectrograph={:d}\nf={:d}, X={:.1f}".format(
                        dataId["visit"], dataId["arm"], dataId["spectrograph"], fiberId, xa[i])
                    _plotFiberOverview(pp, title, calexp, subtracted, chiimage, xo, yo, xa[i],
                                       chiAveSpec[i], centerdif, widthdif, ydif, ymin, ymax)
                    _plotFiberCrossSections(pp, title, image, calexp, subtracted, xo,
                                            ymin, ymax, imageWidth)
        finally:
            # Close the PDF even when a fiber fails so the file handle is released.
            if pp is not None:
                pp.close()

    stats = {"fiberIds": fiberIds, "xa": xa, "pfsArmAve": pfsArmAve, "targetMask": targetMask, "chiSquare":chiSquare,
             "chiAve":chiAve, "chiMedian": chiMedian, "chiStd": chiStd, "chiPeak": chiPeak, "chiShift":chiShift,
             "Xarray": xarray, "Yarray": yarray, "dx": dxarray, "dsigma": dwarray}
    with open(figDir + '/%(visit)d_%(arm)s%(spectrograph)d_chifiber.pickle' % dataId, "wb") as statsPickle:
        pickle.dump(stats, statsPickle)

    aveRange = 5.
    medRange = 5.
    stdRange = 10.
    ql = 0.2
    largeStd = chiStd > stdRange

    with PdfPages(figDir + '/%(visit)d_%(arm)s%(spectrograph)d_chistats.pdf' % dataId) as pp:
        # Mean extracted flux per fiber.
        fig, ax = plt.subplots(1, 2, figsize=(16, 7))
        _scatterByTarget(ax[0], fiberIds, pfsArmAve, targetMask, ct)
        ax[0].set_ylabel("Mean flux of pfsArm")
        ax[0].set_xlabel("fiberId")
        ax[0].set_yscale("log")
        ax[0].legend()
        _scatterByTarget(ax[1], xa, pfsArmAve, targetMask, ct)
        ax[1].set_xlabel("X (pix)")
        ax[1].set_yscale("log")
        ax[1].legend()
        _saveStatsPage(pp, fig, dataId)

        # Average and median chi per fiber.
        _plotSymmetricChiPage(pp, dataId, fiberIds, xa, chiAve, targetMask, ct,
                              "Chi (average)", aveRange, ql, legendOnX=True)
        _plotSymmetricChiPage(pp, dataId, fiberIds, xa, chiMedian, targetMask, ct,
                              "Chi (median)", medRange, ql, legendOnX=False)

        # Chi scatter per fiber; a value of 1 is marked by the grey line.
        fig, ax = plt.subplots(1, 2, figsize=(16, 7))
        print("visit=%(visit)d arm=%(arm)s spectrograph=%(spectrograph)d: " % dataId, np.nanmean(chiStd))
        for axis, xvalues in ((ax[0], fiberIds), (ax[1], xa)):
            axis.plot([np.amin(xvalues), np.amax(xvalues)], [1, 1], 'gray')
            _scatterByTarget(axis, xvalues, chiStd, targetMask, ct)
            _drawLimitArrows(axis, xvalues, largeStd, stdRange - stdRange * ql / 2, stdRange * ql / 2)
            axis.set_ylim(-stdRange * 0.1, stdRange)
        ax[0].set_ylabel("Chi (standard deviation)")
        ax[0].set_xlabel("fiberId")
        ax[0].legend()
        ax[1].set_xlabel("X (pix)")
        ax[1].legend()
        _saveStatsPage(pp, fig, dataId)

        # Mean squared chi per fiber.
        fig, ax = plt.subplots(1, 2, figsize=(16, 7))
        _scatterByTarget(ax[0], fiberIds, chiSquare, targetMask, ct)
        ax[0].set_ylabel("Chi^2 (average)")
        ax[0].set_xlabel("fiberId")
        ax[0].set_yscale("log")
        ax[0].legend()
        _scatterByTarget(ax[1], xa, chiSquare, targetMask, ct)
        ax[1].set_ylabel("Chi^2 (average)")
        ax[1].set_xlabel("X (pix)")
        ax[1].set_yscale("log")
        ax[1].legend()
        _saveStatsPage(pp, fig, dataId)

        # Chi at the profile peak, and the left/right chi asymmetry.
        for values, ylabel in ((chiPeak, "Chi at profile peak pixels"),
                               (chiShift, "Chi (sum of dx=1,2,3 pix) - Chi (sum of dx=-1,-2,-3 pix)")):
            fig, ax = plt.subplots(1, 2, figsize=(16, 7))
            _scatterByTarget(ax[0], fiberIds, values, targetMask, ct)
            ax[0].set_ylabel(ylabel)
            ax[0].set_xlabel("fiberId")
            ax[0].legend()
            _scatterByTarget(ax[1], xa, values, targetMask, ct)
            ax[1].set_xlabel("X (pix)")
            _saveStatsPage(pp, fig, dataId)

        # Full-frame chi image.
        fig = plt.figure(figsize=(10,10))
        disp = afwDisplay.Display(fig)
        disp.scale('linear', -5, 5, Q=1)
        disp.setImageColormap('coolwarm')
        disp.setMaskPlaneColor("REFLINE", afwDisplay.IGNORE)
        disp.mtv(chiimage, title='%(visit)d %(arm)s%(spectrograph)d' % dataId)
        addPfsCursor(disp, detMap)
        fig.savefig(pp, format="pdf", bbox_inches='tight')
        plt.close(fig)

        # Detector maps of the fitted centre and relative width differences.
        for values, title, limit in ((dxarray, "dX (pix)", 0.3),
                                     (dwarray, r"d$\sigma$/$\sigma$", 0.1)):
            fig = plt.figure(figsize=(10, 10))
            ax = fig.add_subplot(1, 1, 1)
            mappable = ax.scatter(xarray, yarray, s=2, c=values, cmap="coolwarm", vmin=-limit, vmax=limit)
            ax.set_title(title)
            ax.set_xlabel("X (pix)")
            ax.set_ylabel("Y (pix)")
            ax.set_xlim(0, data.getDimensions()[0])
            ax.set_ylim(0, data.getDimensions()[1])
            fig.colorbar(mappable, ax=ax)
            fig.savefig(pp, format="pdf", bbox_inches='tight')
            plt.close(fig)

    return np.nanmean(chiStd), np.nanmedian(chiStd)

def drawFiberResidual(ax, data, pfsConfig, detectorMap, fiberId):
    """Scatter pixel values against their x offset from one fiber trace.

    For every detector row the pixels within ``xwin`` columns of the rounded
    trace centre are plotted at ``x - xCenter``, where ``xCenter`` is the
    unrounded trace position in that row.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    data :
        Object exposing ``getDimensions()`` and ``image.array``.
    pfsConfig :
        Unused; kept so existing call sites keep working.
    detectorMap :
        Object exposing ``getXCenter(fiberId, y)``.
    fiberId : int
        Fiber to draw.

    Returns
    -------
    matplotlib.axes.Axes
        The axes passed in.
    """
    xmax = data.getDimensions()[0]
    ymax = data.getDimensions()[1]
    yo = np.arange(0, ymax).astype(np.float64)
    xo = detectorMap.getXCenter(fiberId, yo)
    xs = np.round(xo).astype(int)
    ys = np.round(yo).astype(int)
    image = data.image.array

    offsets = []
    values = []
    for x, y, xc in zip(xs, ys, xo):
        # Clamp to the detector so offsets and pixel values always have the
        # same length; an unclamped negative start would wrap around.
        lo = max(x - xwin, 0)
        hi = min(x + xwin + 1, xmax)
        if hi <= lo:
            continue
        offsets.append(np.arange(lo, hi) - xc)
        values.append(image[y, lo:hi])

    if offsets:
        # One scatter call for all rows instead of one artist per row.
        ax.scatter(np.concatenate(offsets), np.concatenate(values), s=3, c="b")
    return ax


if __name__ == '__main__':
    # Script entry point: run extractionQA on every detectorMap file matching
    # the pattern below in the given rerun and collect the mean / median chi
    # scatter per arm+spectrograph.
    #
    # NOTE: `butler` and `spectrograph` are deliberately assigned at module
    # level here because extractionQA reads module globals.
    import traceback

    if len(sys.argv) < 2:
        sys.exit("usage: {} RERUN_NAME".format(sys.argv[0]))

    start = time.time()
    repoDir = '/work/drp'
    pfsDesignDir = '/work/drp/pfsDesign'
    rerunInput = sys.argv[1]
    rerun = os.path.join(repoDir, 'rerun', 'pfs/internal/%s' % (rerunInput))
    calibRoot = os.path.join(repoDir, 'CALIB')
    butler = dafPersist.Butler(rerun, calibRoot=calibRoot)

    flist = sorted(glob.glob("/work/drp/rerun/pfs/internal/%s/DETECTORMAP/*098627*r3*.fits" % rerunInput))
    chistd = {}
    chistdmed = {}
    figDir = "figures-%s" % rerunInput
    # The output directory must exist, otherwise every save inside
    # extractionQA fails.
    os.makedirs(figDir, exist_ok=True)
    for f in flist:
        # File names are split on "-": field 1 is the visit, field 2 starts
        # with the arm letter followed by the spectrograph digit.
        fields = f.split("/")[-1].split("-")
        visit = int(fields[1])
        arm = fields[2][0]
        spectrograph = int(fields[2][1])
        dataId = dict(visit=visit, spectrograph=spectrograph, arm=arm)
        sp = "%(arm)s%(spectrograph)d" % dataId
        try:
            csa, csm = extractionQA(dataId, figDir, plotBool=True)
        except Exception:
            # Keep going with the next file, but show what went wrong.
            print('error in {}'.format(f))
            traceback.print_exc()
        else:
            if sp not in chistd:
                chistd[sp] = [[visit], [csa]]
                chistdmed[sp] = [[visit], [csm]]
            else:
                chistd[sp][0].append(visit)
                chistd[sp][1].append(csa)
                chistdmed[sp][0].append(visit)
                chistdmed[sp][1].append(csm)
        # Elapsed time since the script started (cumulative over files).
        print("extractionQA: {}sec".format(time.time()-start))

    print(chistd)