#!/usr/bin/env python

import numpy as np
import astropy.io.fits
import matplotlib.pyplot as plt


def readArm(filename):
    """Read the ``WAVELENGTH`` and ``FLUX`` extensions of a FITS file.

    Parameters
    ----------
    filename : str
        Path to a FITS file containing extensions named ``WAVELENGTH`` and
        ``FLUX``.

    Returns
    -------
    wavelength, flux : numpy.ndarray
        The data arrays of the two extensions, in that order.

    Raises
    ------
    ValueError
        If the two arrays do not have the same shape.
    """
    with astropy.io.fits.open(filename) as fd:
        wavelength = fd["WAVELENGTH"].data
        flux = fd["FLUX"].data
    # Explicit check rather than ``assert``: assertions are stripped when
    # Python runs with -O, which would let mismatched arrays through.
    if wavelength.shape != flux.shape:
        raise ValueError(
            "%s: WAVELENGTH shape %s does not match FLUX shape %s"
            % (filename, wavelength.shape, flux.shape))
    return wavelength, flux


def calcSignalNoise(wavelength, flux, num=100, low=600.0, high=1000.0):
    """Bin a spectrum in wavelength and estimate per-bin signal and noise.

    The range from ``low`` to ``high`` is split into ``num`` equal-width,
    half-open bins ``[lower, upper)``. A sample lying exactly on ``high`` is
    therefore not counted. For each bin, the signal is the median flux. The
    noise is a robust standard deviation derived from the interquartile
    range. Bins that contain no samples keep a signal and noise of zero.

    Parameters
    ----------
    wavelength, flux : numpy.ndarray
        Arrays of matching shape. The boolean masks built from
        ``wavelength`` are used to index ``flux``.
    num : int
        Number of bins.
    low, high : float
        Wavelength range covered by the bins. The defaults keep the
        original 600-1000 range.

    Returns
    -------
    center, width, signal, noise : numpy.ndarray
        Length-``num`` arrays holding the bin centres, bin widths, median
        flux and IQR-based noise.
    """
    # For a Gaussian, IQR = 1.349 sigma, so sigma ~= 0.741 * IQR.
    iqrToSigma = 0.741
    bounds = np.linspace(low, high, num=num + 1)
    signal = np.zeros(num, dtype=float)
    noise = np.zeros(num, dtype=float)
    for ii, (lower, upper) in enumerate(zip(bounds[:-1], bounds[1:])):
        # Half-open bin: a sample exactly on an interior edge lands in
        # exactly one bin. With strict '>' and '<' it was dropped entirely.
        select = (wavelength >= lower) & (wavelength < upper)
        if not select.any():
            continue
        data = flux[select]
        signal[ii] = np.median(data)
        lq, uq = np.percentile(data, [25.0, 75.0])
        noise[ii] = iqrToSigma*(uq - lq)

    center = 0.5*(bounds[1:] + bounds[:-1])
    width = np.diff(bounds)
    return center, width, signal, noise


def plotStuff(axes, filename, color, scaling=1.0):
    """Read one FITS file, plot its binned signal, noise and S/N, and return the data.

    Parameters
    ----------
    axes : sequence of matplotlib Axes
        ``axes[0]``, ``axes[1]`` and ``axes[2]`` receive the signal, noise
        and signal/noise curves respectively.
    filename : str
        FITS file passed to ``readArm``.
    color : str
        Matplotlib colour used for all three curves.
    scaling : float
        Multiplicative factor applied to the flux before binning.

    Returns
    -------
    tuple
        ``(wavelength, flux, center, width, signal, noise)``. ``flux`` is
        the scaled flux; the last four come from ``calcSignalNoise``.
    """
    wavelength, flux = readArm(filename)
    # Out-of-place multiply: an in-place ``*=`` raises for integer-typed flux
    # (the float result cannot be cast back) and would also modify the
    # array handed back by the reader.
    flux = flux*scaling
    center, width, signal, noise = calcSignalNoise(wavelength, flux)
    axes[0].plot(center, signal, marker='o', color=color)
    axes[1].plot(center, noise, marker='o', color=color)
    # Empty bins have signal == noise == 0 (-> nan), and constant bins have
    # noise == 0 (-> inf). Allow that without emitting RuntimeWarnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        snr = signal/noise
    axes[2].plot(center, snr, marker='o', color=color)
    return wavelength, flux, center, width, signal, noise


def compareArms(filename1, filename2, *, scaling1=200.0, scaling2=1.0):
    """Plot signal, noise and S/N of two FITS spectra plus their signal ratio.

    Four stacked panels are drawn: signal, noise, signal/noise (both files
    overlaid, ``filename1`` in red, ``filename2`` in blue), and the per-bin
    ratio of the first file's signal to the second's. The median of that
    ratio is drawn as a horizontal line and shown in the panel title.
    The figure is displayed with ``plt.show()``.

    Parameters
    ----------
    filename1, filename2 : str
        FITS files passed to ``plotStuff``.
    scaling1, scaling2 : float
        Flux scaling applied to each file. The defaults reproduce the
        original hard-coded 200.0 and 1.0.
    """
    figure, axes = plt.subplots(4)
    _, _, center, _, signal1, _ = plotStuff(
        axes, filename1, "red", scaling=scaling1)
    signal2 = plotStuff(axes, filename2, "blue", scaling=scaling2)[4]

    axes[0].set_ylim(0, 150.0)
    axes[1].set_ylim(0, 50.0)

    axes[0].set_title("Signal")
    axes[1].set_title("Noise")
    axes[2].set_title("Signal/Noise")

    # Both files are binned with calcSignalNoise's default grid, so the
    # signal arrays align bin-for-bin. Empty bins (signal 0) give nan/inf.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = signal1/signal2
    axes[3].plot(center, ratio, 'ko-')
    axes[3].set_ylim(0.5, 1.2)
    good = np.isfinite(ratio)
    # Median over finite bins only; avoid the empty-slice warning when none are.
    value = np.median(ratio[good]) if good.any() else np.nan
    axes[3].axhline(value)
    axes[3].set_title("Signal ratio (median = %f)" % (value,))

    figure.tight_layout()
    plt.show()


def main(argv=None):
    """Command-line entry point: compare the two FITS files given as arguments.

    Parameters
    ----------
    argv : list of str, optional
        Argument vector including the program name. Defaults to
        ``sys.argv``.

    Exits with a usage message (via ``SystemExit``) when fewer than two file
    names are supplied. Any further arguments are ignored, as before.
    """
    import sys
    if argv is None:
        argv = sys.argv
    if len(argv) < 3:
        prog = argv[0] if argv else "compareArms"
        sys.exit("Usage: %s FILE1.fits FILE2.fits" % (prog,))
    compareArms(argv[1], argv[2])


if __name__ == "__main__":
    main()
