#!/usr/bin/env python3

import numpy as np
from numpy.fft import fft, fftfreq, fftshift
import matplotlib as mpl
mpl.use('SVG')
from matplotlib.ticker import FuncFormatter
import matplotlib.pyplot as pl
from quantiphy import Quantity
# Install quantiphy's map_sf_to_sci_notation as the class-wide scale-factor
# mapper, so it applies to every Quantity rendered later in this script (the
# printed summary and the axis tick labels).
# NOTE(review): presumably this only changes how exponent-style scale factors
# are displayed, not the numeric values -- confirm against the quantiphy docs.
Quantity.set_prefs(map_sf=Quantity.map_sf_to_sci_notation)

# read the input waveforms
def _read_wave(path):
    """Load a two-column text waveform and return its columns.

    The file is parsed with ``np.fromfile`` as a flat sequence of
    whitespace-separated numbers, which is then de-interleaved: values at
    even positions form the first column and values at odd positions form
    the second.

    Args:
        path: Name of the waveform file to read.

    Returns:
        A ``(2, n)`` array; row 0 is the first column (used as time by this
        script) and row 1 is the second column (the signal values).

    Raises:
        ValueError: If the file yields no values, or an odd number of
            values, so it cannot be split into complete pairs.
    """
    samples = np.fromfile(path, sep=' ')
    if samples.size == 0 or samples.size % 2:
        raise ValueError(
            f'{path}: expected a non-empty, even number of values '
            f'(pairs), found {samples.size}'
        )
    # Fortran order sends consecutive values down a column, so each pair
    # lands in one column and the two rows are the de-interleaved series.
    return samples.reshape((2, samples.size // 2), order='F')


# Each file carries its own time column; the one from dout.wave is read last
# and is therefore the `time` used by the rest of the script.
time, vin = _read_wave('vin.wave')
time, vout = _read_wave('vout.wave')
time, wave = _read_wave('dout.wave')

# print out basic information about the data
# The time step is taken from the first two time points only, and the
# nonperiodicity is the gap between the last and first samples of the wave.
# NOTE(review): using one time step for the whole record presumes uniformly
# spaced samples -- confirm for the simulator output being read.
sample_count = len(time)
timestep = Quantity(time[1] - time[0], units='s', name='Time step')
nonperiodicity = Quantity(
    wave[-1] - wave[0], units='V', name='Nonperiodicity'
)
points = Quantity(sample_count, name='Time points')
period = Quantity(timestep * sample_count, units='s', name='Period')
freq_res = Quantity(1/period, units='Hz', name='Frequency resolution')

# Labels and precision are only applied while the quantities are being
# converted to text, so the printing must happen inside the context.
with Quantity.prefs(show_label=True, prec=2):
    for summary_item in (timestep, nonperiodicity, points, period, freq_res):
        print(summary_item)

# create the window
# A Kaiser window with beta=11 corresponds to alpha=3.5 (beta = pi*alpha).
# The processing gain with alpha=3.5 is 0.37, so the window is divided by
# that figure before it is applied to the waveform.
kaiser_window = np.kaiser(len(time), 11)
window = kaiser_window/0.37
windowed = wave*window

# transform the data into the frequency domain
# The FFT is scaled by 2/N and both the spectrum and its frequency axis are
# passed through fftshift so that they stay aligned bin for bin.
raw_fft = fft(windowed)
spectrum = 2*fftshift(raw_fft)/len(time)
freq = fftshift(fftfreq(len(wave), d=timestep))

# define the axis formatting routines
def _quantity_formatter(units):
    """Return a tick formatter that renders tick values as Quantities.

    Args:
        units: Units string handed to ``Quantity`` together with each tick
            value.

    Returns:
        A ``FuncFormatter`` whose labels are ``str(Quantity(value, units))``;
        the tick position argument is ignored.
    """
    def _label(value, position):
        return str(Quantity(value, units))

    return FuncFormatter(_label)


freq_formatter = _quantity_formatter('Hz')
volt_formatter = _quantity_formatter('V')
time_formatter = _quantity_formatter('s')

# generate graphs of the resulting spectrum
# The spectrum magnitude is drawn once and saved twice: first with the
# automatically chosen axis limits, then again after zooming the axes.
fig, ax = pl.subplots()
ax.plot(freq, np.abs(spectrum))
ax.set_yscale('log')
ax.xaxis.set_major_formatter(freq_formatter)
ax.yaxis.set_major_formatter(volt_formatter)
fig.savefig('spectrum.svg')
ax.set_xlim(0, 1e6)
ax.set_ylim(1e-7, 1)
fig.savefig('spectrum-zoomed.svg')

# Time-domain view of vout and vin on a fresh figure; vout is drawn first so
# the two traces keep their original order.
fig, ax = pl.subplots()
ax.xaxis.set_major_formatter(time_formatter)
for trace in (vout, vin):
    ax.plot(time, trace)
fig.savefig('wave.svg')
