import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import quad
from scipy.signal import lombscargle

# Matplotlib theme applied to every figure rendered by this script.
plt.style.use([
    'seaborn-v0_8-darkgrid',
    'seaborn-v0_8-poster',
])

# Parameters of the synthetic periodic box signal shared by all figures.
# Time is in seconds, matching the 't (s)' / 'Frequency (Hz)' axis labels.
PERIOD = 5      # spacing between the starts of successive dips
DURATION = 1    # width of each dip
DROP = 0.5      # depth of each dip: the signal is 1 - DROP inside, 1 outside

def fn(t, period, duration, drop):
    """Evaluate a periodic box-shaped dip at time(s) ``t``.

    Every ``period`` time units the signal falls to ``1 - drop`` for the
    first ``duration`` units of that cycle and is ``1`` otherwise.

    Parameters
    ----------
    t : float or array-like
        Time(s) at which to evaluate the signal.
    period : float
        Repetition period of the dip.
    duration : float
        Width of the dip at the start of each period.
    drop : float
        Depth of the dip.

    Returns
    -------
    numpy.ndarray
        ``np.where`` result with the same shape as ``t`` (0-d for scalars).
    """
    inside_dip = (t % period) < duration
    return np.where(inside_dip, 1 - drop, 1)

def get_an(n, period, duration=None, drop=None):
    """Return the n-th cosine Fourier coefficient of the box signal ``fn``.

    With omega = 2*pi / period:
      * a_0 = (1 / period) * integral_0^period fn(t) dt   (the mean), and
      * a_n = (2 / period) * integral_0^period fn(t) cos(n*omega*t) dt, n >= 1.

    Parameters
    ----------
    n : int
        Harmonic index; 0 yields the mean value.
    period : float
        Signal period and length of the integration interval.
    duration, drop : float, optional
        Dip width and depth passed to ``fn``. When omitted, the module-level
        ``DURATION`` / ``DROP`` are read at call time (the original behaviour).

    Returns
    -------
    float
        The coefficient estimated by ``scipy.integrate.quad``.
    """
    if duration is None:
        duration = DURATION
    if drop is None:
        drop = DROP
    om = 2 * np.pi / period

    if n == 0:
        def integrand(t):
            return 1 / period * fn(t, period, duration, drop)
    else:
        def integrand(t):
            return 2 / period * fn(t, period, duration, drop) * np.cos(n * om * t)

    # fn steps from 1 - drop to 1 at t == duration. Telling quad where the
    # discontinuity lies lets it split the interval there. Otherwise it has
    # to locate the jump by adaptive bisection. The point must lie strictly
    # inside (0, period).
    breaks = [duration] if 0 < duration < period else None
    return quad(integrand, 0, period, points=breaks)[0]

def get_bn(n, period, duration=None, drop=None):
    """Return the n-th sine Fourier coefficient of the box signal ``fn``.

    With omega = 2*pi / period:
      * b_0 = 0, and
      * b_n = (2 / period) * integral_0^period fn(t) sin(n*omega*t) dt, n >= 1.

    Parameters
    ----------
    n : int
        Harmonic index; 0 short-circuits to the integer 0.
    period : float
        Signal period and length of the integration interval.
    duration, drop : float, optional
        Dip width and depth passed to ``fn``. When omitted, the module-level
        ``DURATION`` / ``DROP`` are read at call time (the original behaviour).

    Returns
    -------
    float or int
        The coefficient from ``scipy.integrate.quad``, or ``0`` when n == 0.
    """
    if n == 0:
        return 0
    if duration is None:
        duration = DURATION
    if drop is None:
        drop = DROP
    om = 2 * np.pi / period

    def integrand(t):
        return 2 / period * fn(t, period, duration, drop) * np.sin(n * om * t)

    # Same step discontinuity as in the cosine coefficients: hand it to quad
    # as an explicit break point when it falls strictly inside the interval.
    breaks = [duration] if 0 < duration < period else None
    return quad(integrand, 0, period, points=breaks)[0]

def fn_approx(t, nmax, period):
    """Evaluate the truncated Fourier series of ``fn`` at time(s) ``t``.

    Sums harmonics k = 0 .. nmax-1, so ``nmax`` counts the constant term as
    one of the terms. Coefficients come from ``get_an`` / ``get_bn`` called
    with ``period`` only. Returns 0 when ``nmax`` is 0.
    """
    omega = 2 * np.pi / period
    return sum(
        get_an(k, period) * np.cos(k * omega * t)
        + get_bn(k, period) * np.sin(k * omega * t)
        for k in range(nmax)
    )

def prep_fourier_series_figure():
    """Render the box signal next to truncated Fourier series on a 2x2 grid.

    Panels are:
      * the exact signal (top-left);
      * the 2-, 5- and 20-term approximations from ``fn_approx``.

    All panels span t in [0, 20) s. Only the bottom row gets an x-axis label.
    The figure is saved to ``transit_fourier_series.png`` with a fully
    transparent face colour and then shown.
    """
    ts = np.arange(0, 20, 0.001)

    plt.figure(figsize=(15, 8))
    # Each row: (subplot position, number of terms or None for the exact
    # signal, line colour or None for the default, panel title).
    layout = (
        (221, None, None, 'Original Function'),
        (222, 2, 'C1', '2 term Fourier Approximation'),
        (223, 5, 'C2', '5 term Fourier Approximation'),
        (224, 20, 'C3', '20 term Fourier Approximation'),
    )
    for position, nterms, colour, title in layout:
        plt.subplot(position)
        if nterms is None:
            plt.plot(ts, fn(ts, PERIOD, DURATION, DROP), lw=2)
        else:
            plt.plot(ts, fn_approx(ts, nterms, PERIOD), color=colour, lw=2)
        plt.title(title)
        if position in (223, 224):  # bottom row only
            plt.xlabel('t (s)')

    plt.tight_layout()
    plt.savefig('transit_fourier_series.png', facecolor='#FFFFFF00')
    plt.show()

def prep_fft_figure():
    """Plot the FFT power spectrum of 100 s of the box signal.

    The signal is sampled every 1 ms and mean-subtracted so the DC bin does
    not dominate. Power is |rfft|**2 / N over the non-negative frequencies.
    The plot is limited to 0-6 Hz, saved to ``transit_fft.png`` and shown.
    """
    step = 0.001
    times = np.arange(0, 100, step)
    samples = fn(times, PERIOD, DURATION, DROP)
    samples = samples - samples.mean()  # centre about zero to remove DC

    n_samples = len(times)
    frequencies = np.fft.rfftfreq(n_samples, step)
    spectrum = np.fft.rfft(samples)
    power = 1 / n_samples * np.abs(spectrum) ** 2

    plt.figure(figsize=(15, 6))
    plt.plot(frequencies, power)
    plt.xlim([0, 6])
    plt.title('Power Spectrum from FFT')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power')
    plt.savefig('transit_fft.png', facecolor='#FFFFFF00')
    plt.show()

def prep_lsc_figure(*, seed=None):
    """Plot a Lomb-Scargle periodogram of an irregularly sampled box signal.

    Candidate observation times come from a 1 ms grid over [0, 100) s. Only
    times where clip(sin(2*pi*t / 0.75) + 0.4, 0, 1) exceeds 0.5 are kept,
    and a random half of those is then drawn. The periodic gaps in coverage
    are intended to introduce the aliasing referred to in the plot title.
    The periodogram is evaluated at 1000 frequencies between 0.01 and 6 Hz,
    saved to ``transit_lsc.png`` and shown.

    Parameters
    ----------
    seed : int or None, optional (keyword-only)
        Seed for the random sub-sampling. ``None`` (default) draws from
        NumPy's global random state, exactly as before, so the saved figure
        differs between runs. Pass an int to make the output reproducible.
    """
    ts = np.arange(0, 100, 0.001)
    # Observability window: a 0.75 s cycle; clipping keeps the propensity in
    # [0, 1] and the 0.5 threshold selects the "observable" part of each cycle.
    obs_cycle = np.sin(2 * np.pi * ts / 0.75)
    propensity = obs_cycle + (1 - 0.3 * 2)
    propensity = np.clip(propensity, 0, 1)
    obs_times = ts[propensity > 0.5]

    n_keep = len(obs_times) // 2
    if seed is None:
        picked = np.random.choice(obs_times, n_keep, replace=False)
    else:
        picked = np.random.default_rng(seed).choice(obs_times, n_keep, replace=False)
    ts = np.sort(picked)

    signal = fn(ts, PERIOD, DURATION, DROP)
    signal = signal - signal.mean()  # Need to center about 0
    freqs = np.linspace(1/100, 6, 1000)
    # scipy's lombscargle expects angular frequencies (rad/s).
    afreqs = 2 * np.pi * freqs
    power = lombscargle(ts, signal, afreqs)

    plt.figure(figsize=(15, 6))
    plt.plot(freqs, power)
    plt.xlim([0, 6])
    plt.title('Power Spectrum from Lomb-Scargle with Aliasing')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power')
    plt.savefig('transit_lsc.png', facecolor='#FFFFFF00')
    plt.show()

if __name__ == '__main__':
    import sys

    # Figures can be requested by name on the command line, e.g.
    #   python <script> series fft lsc
    # With no arguments only the Lomb-Scargle figure is rendered, matching
    # the previous default.
    builders = {
        'series': prep_fourier_series_figure,
        'fft': prep_fft_figure,
        'lsc': prep_lsc_figure,
    }
    requested = sys.argv[1:] or ['lsc']
    unknown = [name for name in requested if name not in builders]
    if unknown:
        sys.exit(
            f"unknown figure(s): {', '.join(unknown)}; "
            f"choose from {', '.join(builders)}"
        )
    for name in requested:
        builders[name]()

