import os
import numpy as np
import astropy.modeling
import matplotlib.pyplot as plt
from lsst.afw.image import ExposureF
from pfs.drp.stella import DetectorMap
from pfs.drp.stella.findAndTraceAperturesTask import FindAndTraceAperturesTask


def trace(exposure, detectorMapPath=None):
    """Find and trace apertures in an exposure and print how many were found.

    Parameters
    ----------
    exposure : `lsst.afw.image.Exposure`
        Exposure whose ``maskedImage`` is passed to the tracing task.
    detectorMapPath : `str`, optional
        Path of the detector map FITS file. Defaults to
        ``$OBS_PFS_DIR/pfs/camera/detectorMap-sim-r1.fits``.

    Returns
    -------
    traces
        The result of ``FindAndTraceAperturesTask.run``; its length is printed.

    Raises
    ------
    KeyError
        If ``detectorMapPath`` is not given and ``OBS_PFS_DIR`` is not set
        in the environment.
    """
    if detectorMapPath is None:
        detectorMapPath = os.path.join(os.environ["OBS_PFS_DIR"],
                                       "pfs/camera/detectorMap-sim-r1.fits")
    detMap = DetectorMap.readFits(detectorMapPath)
    # The task configuration is left at its defaults.
    config = FindAndTraceAperturesTask.ConfigClass()
    task = FindAndTraceAperturesTask(config=config)
    traces = task.run(exposure.maskedImage, detMap)
    print(len(traces))
    return traces


def fitWidth(exposure, row=2048, windows=None, width=2.5, plot=False):
    """Fit a Gaussian plus linear background to line profiles along one image row.

    For each ``(lowIndex, highIndex)`` column window, the pixel values
    ``exposure.image.array[row, lowIndex:highIndex]`` are fit with an astropy
    ``Gaussian1D + Linear1D`` compound model. The window number and fitted
    Gaussian sigma are printed for each window.

    Parameters
    ----------
    exposure : `lsst.afw.image.Exposure`
        Exposure whose ``image.array`` is sampled.
    row : `int`, optional
        Row (first array index) to sample.
    windows : iterable of (`int`, `int`), optional
        Half-open ``[lowIndex, highIndex)`` column windows, each expected to
        contain a single line. Defaults to a hard-coded list (presumably chosen
        by eye for one particular exposure -- TODO confirm).
    width : `float`, optional
        Initial guess for the Gaussian sigma, in pixels.
    plot : `bool`, optional
        If true, show the data and the fitted model for each window.

    Returns
    -------
    sigmas : `list` of `float`
        Fitted Gaussian sigma for each window, in window order.
    """
    if windows is None:
        windows = [(245, 275), (455, 485), (1220, 1250), (1500, 1530), (1600, 1630),
                   (1980, 2010), (2110, 2140), (2225, 2255), (2435, 2475), (2635, 2665),
                   (2825, 2855), (3330, 3360), (3630, 3660), (3820, 3850)]

    fitter = astropy.modeling.fitting.LevMarLSQFitter()
    sigmas = []
    for ii, (lowIndex, highIndex) in enumerate(windows):
        indices = np.arange(lowIndex, highIndex)
        values = exposure.image.array[row, lowIndex:highIndex]
        # Seed the Gaussian at the brightest pixel; argmax is relative to the slice.
        peak = np.argmax(values)
        amplitude = values[peak]

        lineModel = astropy.modeling.models.Gaussian1D(amplitude, peak + lowIndex, width,
                                                       bounds={"mean": (lowIndex, highIndex)},
                                                       name="line")
        bgModel = astropy.modeling.models.Linear1D(0.0, 0.0, name="bg")
        fit = fitter(lineModel + bgModel, indices, values)
        # Compound-model parameters carry a component-index suffix; "_0" is the Gaussian.
        sigma = fit.stddev_0.value
        print(ii, sigma)
        sigmas.append(sigma)

        if plot:
            _, axes = plt.subplots()
            axes.plot(indices, values, 'k-')
            axes.plot(indices, fit(indices), 'b:')
            axes.text(lowIndex, amplitude, "center = %f\nsigma = %f" % (fit.mean_0.value, sigma))
            plt.show()

    return sigmas


def main(filename="postIsrCcd-LA013183r1.fits", doFitWidth=False):
    """Read an exposure, trace its apertures, and optionally fit line widths.

    Parameters
    ----------
    filename : `str`, optional
        Path of the exposure FITS file to read.
    doFitWidth : `bool`, optional
        If true, also run `fitWidth` on the exposure.
    """
    exp = ExposureF(filename)
    trace(exp)
    if doFitWidth:
        fitWidth(exp)


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
