import pickle
from glob import glob
from types import SimpleNamespace

import numpy as np
import scipy.optimize
import matplotlib.pyplot as plt

from lsst.afw.math.functionLibrary import PolynomialFunction2D, Chebyshev1Function2D
from lsst.afw.math import LeastSquares

from pfs.drp.stella.arcLine import ArcLineSet


def calculateXiEta(lines, good, params):
    """Map the selected arc lines onto normalized (xi, eta) coordinates.

    ``xi`` is ``fiberId*fiberPitch - xi0`` plus the fiber's entry in ``params.spatialOffsets``;
    ``eta`` is ``wavelength/dispersion - eta0`` plus the fiber's entry in
    ``params.spectralOffsets``. Fibers with no entry get a zero offset. Both coordinates are
    then divided by ``params.scale``.

    Reference (C++ formulae this was modelled on)::

        double const xi = getFiberPitch()*(fiberId - getFiberCenter()) + getSpatialOffset(fiberIndex);
        double const eta = (wavelength - getWavelengthCenter())/getDispersion() + getSpectralOffset(fiberIndex);

    Parameters
    ----------
    lines : `ArcLineSet`-like
        Provides ``fiberId`` and ``wavelength`` arrays.
    good : `numpy.ndarray` of `bool`
        Selection mask applied to ``lines``.
    params : `types.SimpleNamespace`
        Model parameters: ``fiberPitch``, ``xi0``, ``spatialOffsets`` (dict keyed by fiberId),
        ``dispersion``, ``eta0``, ``spectralOffsets`` (dict keyed by fiberId), ``scale``.

    Returns
    -------
    xi, eta : `numpy.ndarray`
        Normalized coordinates of the selected lines.
    """
    selectedFibers = lines.fiberId[good]
    selectedWavelengths = lines.wavelength[good]

    def lookup(offsets):
        # Per-fiber offset; fibers that have not been fitted yet contribute nothing
        return np.array([offsets.get(fiber, 0.0) for fiber in selectedFibers])

    spatial = lookup(params.spatialOffsets)
    spectral = lookup(params.spectralOffsets)

    position = selectedFibers*params.fiberPitch - params.xi0 + spatial
    dispersed = selectedWavelengths/params.dispersion - params.eta0 + spectral
    xi = position/params.scale
    eta = dispersed/params.scale

    # Diagnostic: coordinate ranges covered by the selected lines
    print(xi.min(), xi.max(), eta.min(), eta.max())
    return xi, eta


def _fitFiberOffsets(design, residual, errors):
    """Solve for one constant per design column, weighting each line by ``1/errors``."""
    weightedDesign = design/errors[:, np.newaxis]
    eqn = LeastSquares.fromDesignMatrix(weightedDesign, residual/errors)
    return eqn.getSolution()


def fitParams(lines, good, polyFit, oldParams):
    """Fit per-fiber offsets and zero-points with the distortion polynomials held fixed.

    Model::

        x = pitch*fiberId - xi0 + xiOffset[fiberId] + P(xi,eta)
        y = wavelength/dispersion - eta0 + etaOffset[fiberId] + Q(xi,eta)

    For each coordinate, the residual between the measured position and the current
    polynomial model is fit with one free constant per fiber, weighted by the measurement
    errors. The model includes the constant ``polyFit.dx``/``polyFit.dy`` for lines with
    ``x >= 2048``. The fitted constants are divided by ``oldParams.scale``. Their mean is
    subtracted from ``xi0``/``eta0``, and the mean-subtracted values become the new per-fiber
    offsets.

    Parameters
    ----------
    lines : `ArcLineSet`-like
        Provides ``fiberId``, ``wavelength``, ``x``, ``y``, ``xErr`` and ``yErr`` arrays.
    good : `numpy.ndarray` of `bool`
        Selection mask for ``lines``.
    polyFit : `types.SimpleNamespace`
        Result of ``fitPolynomials``: ``xPoly``, ``yPoly``, ``dx``, ``dy``.
    oldParams : `types.SimpleNamespace`
        Current model parameters (see ``calculateXiEta``).

    Returns
    -------
    params : `types.SimpleNamespace`
        Updated parameters. ``fiberPitch``, ``dispersion`` and ``scale`` are copied unchanged.
    """
    numLines = good.sum()
    fiberId = lines.fiberId[good]
    uniqueFibers = sorted(set(fiberId))
    fiberMapping = {ff: ii for ii, ff in enumerate(uniqueFibers)}
    numFibers = len(uniqueFibers)
    fiberIndices = np.array([fiberMapping[ff] for ff in fiberId])

    xi, eta = calculateXiEta(lines, good, oldParams)

    # One indicator column per fiber. The weighted LSQ solution is then the per-fiber
    # weighted mean of the residuals.
    design = np.zeros((numLines, numFibers), dtype=float)
    design[np.arange(numLines, dtype=int), fiberIndices] = 1

    xModel = np.array([polyFit.xPoly(xx, yy) for xx, yy in zip(xi, eta)])
    yModel = np.array([polyFit.yPoly(xx, yy) for xx, yy in zip(xi, eta)])

    # fitPolynomials applies its constant offset (dx AND dy) to lines selected by x >= 2048,
    # so both models must be selected on x here as well.
    xModel += np.where(lines.x[good] < 2048, 0, polyFit.dx)
    yModel += np.where(lines.x[good] < 2048, 0, polyFit.dy)

    xSolution = _fitFiberOffsets(design, lines.x[good] - xModel, lines.xErr[good])/oldParams.scale
    ySolution = _fitFiberOffsets(design, lines.y[good] - yModel, lines.yErr[good])/oldParams.scale

    # Absorb the mean offset into the zero-points so per-fiber offsets average to zero
    xMean = xSolution.mean()
    yMean = ySolution.mean()
    params = SimpleNamespace(
        fiberPitch=oldParams.fiberPitch,
        xi0=oldParams.xi0 - xMean,
        spatialOffsets=dict(zip(uniqueFibers, xSolution - xMean)),
        dispersion=oldParams.dispersion,
        eta0=oldParams.eta0 - yMean,
        spectralOffsets=dict(zip(uniqueFibers, ySolution - yMean)),
        scale=oldParams.scale,
    )
    print(params.xi0, params.eta0, xSolution.min(), xSolution.max(), ySolution.min(), ySolution.max())
    return params


def _fitCoordinate(design, values, errors, xi, eta, Polynomial, useDesignMatrix):
    """Weighted linear least-squares fit of one detector coordinate.

    Parameters
    ----------
    design : `numpy.ndarray`, shape ``(numLines, numTerms + 1)``
        Unweighted design matrix. The leading columns are the polynomial basis functions
        evaluated at ``(xi, eta)``. The final 0/1 column selects lines that receive the
        constant offset.
    values : `numpy.ndarray`, shape ``(numLines,)``
        Measured coordinate (``x`` or ``y``).
    errors : `numpy.ndarray`, shape ``(numLines,)``
        Measurement errors; each row is weighted by ``1/errors``.
    xi, eta : `numpy.ndarray`
        Normalized coordinates at which the fitted polynomial is evaluated.
    Polynomial : `type`
        Polynomial class, constructed from the fitted coefficients.
    useDesignMatrix : `bool`
        If true, solve the weighted design matrix with SVD. Otherwise solve the normal
        equations.

    Returns
    -------
    poly : ``Polynomial``
        Fitted polynomial (all but the last solution element).
    offset : `float`
        Fitted constant offset (last solution element).
    resid : `numpy.ndarray`
        Residuals in units of ``errors``.
    """
    weightedDesign = design/errors[:, np.newaxis]
    weightedValues = values/errors

    if useDesignMatrix:
        eqn = LeastSquares.fromDesignMatrix(weightedDesign, weightedValues.astype(float),
                                            LeastSquares.DIRECT_SVD)
    else:
        fisher = np.matmul(weightedDesign.T, weightedDesign)
        rhs = np.matmul(weightedDesign.T, weightedValues)
        eqn = LeastSquares.fromNormalEquations(fisher, rhs)
    print(eqn.getDimension(), eqn.getRank())

    solution = eqn.getSolution()
    print(solution)
    poly = Polynomial(solution[:-1])

    if useDesignMatrix:
        resid = weightedValues - np.matmul(weightedDesign, solution)
    else:
        # Evaluate the fitted polynomial itself, as an independent check on the solution
        fit = np.array([poly(xx, yy) for xx, yy in zip(xi, eta)])
        fit += np.where(design[:, -1] > 0, solution[-1], 0.0)
        resid = (values - fit)/errors
    print(np.sum(resid**2), len(values))
    return poly, solution[-1], resid


def fitPolynomials(lines, good, params, order=7):
    """Fit x and y as 2-D Chebyshev polynomials in (xi, eta).

    Each coordinate is modelled as a Chebyshev polynomial of the given ``order`` in the
    normalized coordinates from ``calculateXiEta``, plus a constant offset for lines with
    ``x >= 2048``. Each line is weighted by its measurement error.

    Parameters
    ----------
    lines : `ArcLineSet`-like
        Provides ``fiberId``, ``wavelength``, ``x``, ``y``, ``xErr`` and ``yErr`` arrays.
    good : `numpy.ndarray` of `bool`
        Selection mask for ``lines``.
    params : `types.SimpleNamespace`
        Model parameters used to compute (xi, eta) (see ``calculateXiEta``).
    order : `int`
        Polynomial order.

    Returns
    -------
    result : `types.SimpleNamespace`
        ``xPoly``, ``yPoly``: fitted polynomials. ``xResid``, ``yResid``: residuals in sigma.
        ``dx``, ``dy``: fitted constant offsets for lines with ``x >= 2048``.
    """
    Polynomial = Chebyshev1Function2D
    useDesignMatrix = False  # True: SVD of the design matrix; False: normal equations

    numLines = good.sum()
    numTerms = Polynomial.nParametersFromOrder(order)
    basis = Polynomial(np.zeros(numTerms))

    xi, eta = calculateXiEta(lines, good, params)

    # The same design matrix serves both coordinates; the last column selects the lines that
    # get the constant offset, keyed on x for both x and y.
    design = np.zeros((numLines, numTerms + 1), dtype=float)
    design[:, :-1] = np.array([basis.getDFuncDParameters(xx, yy) for xx, yy in zip(xi, eta)])
    design[:, -1] = np.where(lines.x[good] < 2048, 0.0, 1.0)

    xPoly, dx, xResid = _fitCoordinate(design, lines.x[good], lines.xErr[good], xi, eta,
                                       Polynomial, useDesignMatrix)
    yPoly, dy, yResid = _fitCoordinate(design, lines.y[good], lines.yErr[good], xi, eta,
                                       Polynomial, useDesignMatrix)

    return SimpleNamespace(xPoly=xPoly, yPoly=yPoly, xResid=xResid, yResid=yResid, dx=dx, dy=dy)


def _findSoftening(excessChi2, upper=1.0):
    """Return the systematic error at which ``excessChi2`` crosses zero.

    ``excessChi2(sysErr)`` must be positive at zero and decrease as ``sysErr`` grows. The
    bracket starts as ``[0, upper]``. ``upper`` is doubled until the function is no longer
    positive there, so `scipy.optimize.bisect` always has a sign change to work with.
    """
    while excessChi2(upper) > 0:
        upper *= 2
    return scipy.optimize.bisect(excessChi2, 0.0, upper)


def main(order=7, iterations=3, threshold=4.0, pattern="weekly/arcLines-*-r1.fits", pickleFile=None):
    """Fit the distortion model to arc lines, then estimate a systematic error floor.

    Polynomial and per-fiber offset fits alternate for ``iterations`` rounds. After each
    round, lines whose x or y residual is at least ``threshold`` sigma are rejected. A final
    fit is then made with the surviving lines. The weighted residual RMS is reported, along
    with the systematic error which, added in quadrature to the measurement errors, brings
    the reduced chi^2 to 1 (for x and y combined, x alone and y alone).

    Parameters
    ----------
    order : `int`
        Polynomial order passed to ``fitPolynomials``.
    iterations : `int`
        Number of fit-and-reject rounds before the final fit.
    threshold : `float`
        Rejection threshold, in units of the measurement error.
    pattern : `str`
        Glob pattern of ``ArcLineSet`` FITS files to read (ignored if ``pickleFile`` is set).
    pickleFile : `str` or `None`
        If set, load a pickled line set from this path instead of reading FITS files.
    """
    if pickleFile is not None:
        # pickle can execute arbitrary code on load: only use files you produced yourself
        with open(pickleFile, "rb") as fd:
            lines = pickle.load(fd)
    else:
        lines = ArcLineSet.empty()
        for ff in glob(pattern):
            lines.extend(ArcLineSet.readFits(ff))

    good = ~lines.flag.astype(bool) & np.isfinite(lines.x) & np.isfinite(lines.y)
    good &= np.isfinite(lines.xErr) & np.isfinite(lines.yErr)
    print(good.sum())

    params = SimpleNamespace(
        fiberPitch=6.0,
        xi0=6*325,
        spatialOffsets={},
        dispersion=0.08,
        eta0=10000,
        spectralOffsets={},
        scale=2100,
    )

    for _ in range(iterations):
        polyFit = fitPolynomials(lines, good, params, order=order)
        params = fitParams(lines, good, polyFit, params)
        good[good] &= (np.abs(polyFit.xResid) < threshold) & (np.abs(polyFit.yResid) < threshold)
        print(good.sum())

    # Refit with the surviving lines so the residuals below line up with the final ``good``
    polyFit = fitPolynomials(lines, good, params, order=order)
    params = fitParams(lines, good, polyFit, params)

    xErr = lines.xErr[good]
    yErr = lines.yErr[good]
    xResid = polyFit.xResid*xErr  # residuals from sigma back to pixels
    yResid = polyFit.yResid*yErr
    numLines = good.sum()

    print(xResid.std(), yResid.std())
    xWeights = 1.0/xErr**2
    yWeights = 1.0/yErr**2
    # Inverse-variance weighted RMS of the residuals
    xRms = np.sqrt(np.sum(xWeights*xResid**2)/np.sum(xWeights))
    yRms = np.sqrt(np.sum(yWeights*yResid**2)/np.sum(yWeights))
    print(xRms, yRms)

    # Each function returns (reduced chi^2 - 1) after adding sysErr in quadrature to the errors
    def xSoften(sysErr):
        return np.sum(xResid**2/(xErr**2 + sysErr**2))/numLines - 1

    def ySoften(sysErr):
        return np.sum(yResid**2/(yErr**2 + sysErr**2))/numLines - 1

    def softenedChi2(sysErr):
        chi2 = np.sum(xResid**2/(xErr**2 + sysErr**2) + yResid**2/(yErr**2 + sysErr**2))
        return chi2/(2*numLines) - 1

    for excessChi2, noSofteningMessage in ((softenedChi2, "No softening"),
                                           (xSoften, "No x softening"),
                                           (ySoften, "No y softening")):
        if excessChi2(0.0) > 0:
            print(_findSoftening(excessChi2))
        else:
            print(noSofteningMessage)



if __name__ == "__main__":
    # Script entry point: run the fit with main()'s default settings
    main()
