#!/usr/bin/env python

from collections import defaultdict

import numpy as np

from lsst.daf.persistence import Butler


# Butler data repository (a rerun) from which readData() loads the wlFitData.
# NOTE(review): the rerun name ends in "arc", so these are presumably arc
# exposures — confirm.
dataDir = "/projects/HSC/PFS/Subaru/rerun/price/pipe2d-618/arc"
# Visit numbers whose wlFitData are read and combined by readData(). They are
# listed in runs of four consecutive visit numbers.
visits = [18232, 18233, 18234, 18235, 18249, 18250, 18251, 18252, 18267, 18268, 18269, 18270, 18285, 18286, 18287, 18288, 18303, 18304, 18305, 18306, 18321, 18322, 18323, 18324, 18339, 18340, 18341, 18342, 18357, 18358, 18359, 18360, 18375, 18376, 18377, 18378, 18393, 18394, 18395, 18396, 18411, 18412, 18413, 18414, 18429, 18430, 18431, 18432, 18447, 18448, 18449, 18450, 18465, 18466, 18467, 18468, 18483, 18484, 18485, 18486, 18501, 18502, 18503, 18504, 18519, 18520, 18521, 18522, 18537, 18538, 18539, 18540, 18556, 18557, 18558, 18559, 18574, 18575, 18576, 18577, 18592, 18593, 18594, 18595, 18610, 18611, 18612, 18613, 18628, 18629, 18630, 18631, 18646, 18647, 18648, 18649, 18664, 18665, 18666, 18667, 18682, 18683, 18684, 18685, 18700, 18701, 18702, 18703, 18718, 18719, 18720, 18721, 18736, 18737, 18738, 18739, 18754, 18755, 18756, 18757, 18772, 18773, 18774, 18775, 18790, 18791, 18792, 18793, 18808, 18809, 18810, 18811, 18826, 18827, 18828, 18829, 18844, 18845, 18846, 18847, 18862, 18863, 18864, 18865, 18880, 18881, 18882, 18883, 18898, 18899, 18900, 18901, 18916, 18917, 18918, 18919, 18934, 18935, 18936, 18937, 18955, 18956, 18957, 18958, 18973, 18974, 18975, 18976, 19005, 19006, 19007, 19008, 19023, 19024, 19025, 19026, 19041, 19042, 19043, 19044, 19059, 19060, 19061, 19062, 19077, 19078, 19079, 19080, 19095, 19096, 19097, 19098, 19113, 19114, 19115, 19116, 19131, 19132, 19133, 19134, 19149, 19150, 19151, 19152, 19167, 19168, 19169, 19170, 19185, 19186, 19187, 19188, 19203, 19204, 19205, 19206, 19221, 19222, 19223, 19224, 19239, 19240, 19241, 19242, 19257, 19258, 19259, 19260, 19275, 19276, 19277, 19278, 19293, 19294, 19295, 19296, 19311, 19312, 19313, 19314, 19329, 19330, 19331, 19332, 19347, 19348, 19349, 19350, 19365, 19366, 19367, 19368, 19383, 19384, 19385, 19386]


def robustRms(array):
    """Return a robust (outlier-resistant) estimate of the scatter of ``array``.

    The interquartile range (75th minus 25th percentile) is multiplied by
    0.741, which is approximately 1/1.349. For a Gaussian, the IQR is 1.349
    standard deviations, so for normally distributed data the result
    approximates the standard deviation.
    """
    quartiles = np.percentile(array, (25.0, 75.0))
    interquartileRange = quartiles[1] - quartiles[0]
    return 0.741*interquartileRange


def readData(arm="r"):
    """Load the wlFitData of every visit in ``visits`` and group the lines.

    A Butler is opened on ``dataDir``. For each visit, the ``wlFitData``
    dataset for the requested ``arm`` is fetched. Visits whose dataset cannot
    be read are reported and skipped. Every entry of the loaded datasets is
    filed under its ``refWavelength``.

    Parameters
    ----------
    arm : `str`
        Spectrograph arm identifier passed to ``butler.get``.

    Returns
    -------
    data : `collections.defaultdict` of `list`
        Mapping from ``refWavelength`` to the list of wlFitData entries with
        that reference wavelength, accumulated across all readable visits.
    """
    butler = Butler(dataDir)
    linesByWavelength = defaultdict(list)
    for visit in visits:
        try:
            fitData = butler.get("wlFitData", visit=visit, arm=arm)
        except Exception:
            # Best effort: a missing/unreadable visit should not abort the run.
            print(f"Unable to read wlFitData for visit={visit} arm={arm}")
        else:
            for measurement in fitData:
                linesByWavelength[measurement.refWavelength].append(measurement)
    return linesByWavelength


def plot(fiberId, wavelength, values, title, cmap=None):
    """Scatter-plot ``values`` at (fiberId, wavelength) positions, with a colorbar.

    Each point is drawn at ``(fiberId[i], wavelength[i])`` and coloured by
    ``values[i]``. The colour comes from ``cmap`` (``matplotlib.cm.rainbow``
    when ``cmap`` is None) after normalizing over the range of ``values``. A
    vertical colorbar is attached on the right, and the figure is displayed
    with ``plt.show()``.

    Parameters
    ----------
    fiberId : sequence
        x coordinates of the points.
    wavelength : sequence
        y coordinates of the points.
    values : sequence of numbers
        Quantity mapped to colour.
    title : `str`
        Axes title.
    cmap : colormap, optional
        Colormap to use; defaults to rainbow.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
    axes : `matplotlib.axes.Axes`
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    from matplotlib.colors import Normalize
    import matplotlib.cm
    import matplotlib.pyplot as plt

    colormap = matplotlib.cm.rainbow if cmap is None else cmap

    fig = plt.figure()
    axes = fig.add_subplot(1, 1, 1)
    # Split a thin (5%) axis off the right-hand side to host the colorbar.
    colorbarAxes = make_axes_locatable(axes).append_axes('right', size='5%', pad=0.05)

    norm = Normalize()
    norm.autoscale(values)
    pointColors = colormap(norm(values))
    axes.scatter(fiberId, wavelength, color=pointColors)
    axes.set_title(title)
    axes.set_xlabel("fiberId")
    axes.set_ylabel("Wavelength (nm)")

    # The scatter was given explicit colours, so a separate ScalarMappable
    # sharing the same norm/colormap drives the colorbar.
    mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=colormap)
    mappable.set_array([])
    fig.colorbar(mappable, cax=colorbarAxes, orientation='vertical')
    plt.show()
    return fig, axes


def main(arm="b", minPoints=50):
    """Plot the robust scatter of arc-line positions per line and fiber.

    The line measurements for all ``visits`` are read with `readData`. For
    every (reference wavelength, fiberId) pair, the robust RMS of ``xCenter``
    and of ``measuredPosition`` across visits is computed. The results are
    shown as two scatter plots of fiberId vs wavelength, coloured by RMS.

    Parameters
    ----------
    arm : `str`
        Spectrograph arm to analyse.
    minPoints : `int`
        Minimum number of measurements, with both ``xCenter`` and
        ``measuredPosition`` finite, that a (wavelength, fiberId) group needs
        in order to be included. The default of 50 matches the original
        hard-coded cut.
    """
    data = readData(arm)
    wavelength = []
    fiberId = []
    xRms = []
    yRms = []
    for wl in sorted(data.keys()):
        # Group this line's measurements by fiber in a single pass. Filtering
        # the whole list once per fiber was quadratic in the number of fibers.
        byFiber = defaultdict(list)
        for dd in data[wl]:
            byFiber[dd.fiberId].append(dd)
        for ff in sorted(byFiber):
            lines = byFiber[ff]
            xx = np.array([dd.xCenter for dd in lines])
            yy = np.array([dd.measuredPosition for dd in lines])
            good = np.isfinite(xx) & np.isfinite(yy)
            if good.sum() < minPoints:
                continue
            wavelength.append(wl)
            fiberId.append(ff)
            xRms.append(robustRms(xx[good]))
            yRms.append(robustRms(yy[good]))

    if not wavelength:
        # Nothing survived the cut (or nothing was read); plot() would be
        # asked to colour-normalize an empty array, so report and stop.
        print(f"No (wavelength, fiberId) groups with at least {minPoints} "
              f"good measurements for arm={arm}; nothing to plot")
        return

    plot(fiberId, wavelength, xRms, "xRms")
    plot(fiberId, wavelength, yRms, "yRms")


if __name__ == "__main__":
    main()
