# SCRIPTS FFT FFT_travelingWave

# -*-Python-*-
# Created by bgrierson at 27 Feb 2017  10:21

# Basic (and real) FFT
# The basic plane wave has the form:
#  A(x,t) = A0 * cos(k*x-omega*t + phi)
# Where A0 is the amplitude
# k is the wavenumber (2*pi/l, l = wavelength)
# omega is the frequency (2*pi*f)
# phi is the phase
#
# We use two functions y1(x,t), y2(x,t) where y2 has a positive phase shift phi.
# In this formulation the waveform travels to the right (+x) as time advances. The phase-shifted
# waveform y2 also travels to the right, but is displaced to the left by phi/k in space, which is
# equivalent to a delay of phi/omega in time: y2(x,t) = y1(x, t - phi/omega).
#
# In complex notation this is
# A(x,t) = A0 exp[i(k*x - omega*t + phi)]
# or
# A(x,t) = A0[cos(k*x - omega*t + phi) + i sin(k*x - omega*t + phi)]
#
# In the complex plane we have the (x,y) as Re,Im axes.
# Im
# |
# |
# |______ Re
# At zero argument the value is purely real, i.e. A(0,0) = A0*cos(0) = A0 (= 1 for unit amplitude)
# The angle is atan(y/x) = atan(Im/Re)
# For a pure cosine the phase is zero.
#
# Let Y1 = FFT(y1) and Y2 = FFT(y2)
# The cross-spectrum is
# Y12 = Y1(c) * Y2 where (c) denotes complex-conjugate.
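# The cross-phase is
# phi12 = atan2(Im(Y12), Re(Y12)) = phase(Y2) - phase(Y1)
# At x=0, y2(t) = cos(omega*t - phi), so the temporal cross-phase at f is -phi;
# at t=0, y2(x) = cos(k*x + phi), so the spatial cross-phase at k is +phi.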
#

# Apply a Hann window (np.hanning) to each record before taking its FFT?
win = True

# Space grid
xmin = 0.0
xmax = 1.0
nx = 101
x = np.linspace(xmin, xmax, nx)
dx = x[1] - x[0]  # uniform spacing; sets the wavenumber axis of the spatial FFT

# Time grid
tmin = 0.0
tmax = 2.0
nt = 1024
t = np.linspace(tmin, tmax, nt)
dt = t[1] - t[0]  # uniform sample interval; sets the frequency axis of the temporal FFT

# Wavelength (same length units as x)
lam = 0.25
k = 2.0 * np.pi / lam  # angular wavenumber (rad per unit length)

# Frequency (Hz)
f = 5.0
omega = 2.0 * np.pi * f  # angular frequency (rad/s)

# Phase shift of second signal (rad)
phi = np.pi / 4.0

# With the default 'xy' indexing, meshgrid(t, x) returns arrays of shape (nx, nt):
# row i is position x[i] and column j is time t[j]. Hence y[0, :] is the time
# series at x = x[0] and y[:, 0] is the spatial profile at t = t[0].
tt, xx = np.meshgrid(t, x)
y1 = np.cos(k * xx - omega * tt)
y2 = np.cos(k * xx - omega * tt + phi)

# Taper each record with a Hann window when win is True. Otherwise use unit weights
# (a rectangular window), which leave the samples exactly unchanged.
time_win = np.hanning(nt) if win else np.ones(nt)
space_win = np.hanning(nx) if win else np.ones(nx)

# One-sided FFTs: in time at x = x[0] (row 0), and in space at t = t[0] (column 0).
fspec1 = np.fft.rfft(y1[0, :] * time_win)
fspec2 = np.fft.rfft(y2[0, :] * time_win)
kspec1 = np.fft.rfft(y1[:, 0] * space_win)
kspec2 = np.fft.rfft(y2[:, 0] * space_win)

# Axes of the one-sided spectra: frequency (Hz) and angular wavenumber (rad per unit length)
fvals = np.fft.rfftfreq(nt, d=dt)
kvals = np.fft.rfftfreq(nx, d=dx) * 2.0 * np.pi

# Phase of each signal's spectrum (np.angle == arctan2(imag, real), in [-pi, pi])
fphase1 = np.angle(fspec1)
fphase2 = np.angle(fspec2)
kphase1 = np.angle(kspec1)
kphase2 = np.angle(kspec2)

# Cross-spectra conj(Y1) * Y2; their phase is phase(Y2) - phase(Y1), wrapped to [-pi, pi]
fspec12 = fspec1.conjugate() * fspec2
kspec12 = kspec1.conjugate() * kspec2
fphase12 = np.angle(fspec12)
kphase12 = np.angle(kspec12)

# Indices of the spectral bins nearest the imposed frequency f and wavenumber k
whf = np.argmin(np.abs(fvals - f))
whk = np.argmin(np.abs(kvals - k))

print('phase of y1/pi: {}'.format(fphase1[whf] / np.pi))
print('phase of y2/pi: {}'.format(fphase2[whf] / np.pi))
print('phi/pi: {}'.format(phi / np.pi))
# Temporal cross-phase at f (expected near -phi/pi) and spatial one at k (expected near +phi/pi)
print('phi/pi(@f): {}'.format(fphase12[whf] / np.pi))
print('phi/pi(@k): {}'.format(kphase12[whk] / np.pi))

# Diagnostic figure. The left column is the time domain at x = x[0]; the right column
# is the space domain at t = t[0]. Rows: signals, spectrum of y1, spectrum of y2, phases.


def _plot_spectrum(axis, xvals, spec, name, xmark):
    """Plot Re, Im and magnitude of complex spectrum *spec* against *xvals* on *axis*.

    A dashed vertical line marks *xmark* (the imposed f or k). Both axes use symlog
    scaling so the signed real and imaginary parts stay visible.
    """
    axis.plot(xvals, spec.real, label='Re({})'.format(name))
    axis.plot(xvals, spec.imag, label='Im({})'.format(name))
    axis.plot(xvals, np.abs(spec), label='|{}|'.format(name))
    ymin, ymax = axis.get_ylim()
    axis.vlines(xmark, ymin, ymax, linestyle='dashed')
    axis.set_xscale('symlog')
    axis.set_yscale('symlog')
    axis.legend()


def _plot_phases(axis, xvals, phase1, phase2, phase12, xmark, ymark, mark_label, xlabel):
    """Plot both spectral phases and their cross-phase (units of pi) with the expected value marked."""
    axis.plot(xvals, phase1 / np.pi, label='$\\phi 1$')
    axis.plot(xvals, phase2 / np.pi, label='$\\phi 2$')
    axis.plot(xvals, phase12 / np.pi, label='$\\phi 12$')
    axis.plot(xmark, ymark, color='black', marker='o', label=mark_label)
    ymin, ymax = axis.get_ylim()
    axis.vlines(xmark, ymin, ymax, linestyle='dashed')
    axis.set_ylabel('$\\phi/\\pi$')
    axis.set_xscale('symlog')
    axis.set_xlabel(xlabel)
    axis.legend()


fig, ax = plt.subplots(nrows=4, ncols=2, figsize=(12, 8))

# Time series at x = x[0], plus the windowed versions that were transformed
ax[0, 0].plot(t, y1[0, :], label='y1(t)')
ax[0, 0].plot(t, y2[0, :], label='y2(t)')
if win:
    hann_t = np.hanning(nt)
    ax[0, 0].plot(t, y1[0, :] * hann_t, label='y1(win)')
    ax[0, 0].plot(t, y2[0, :] * hann_t, label='y2(win)')
ax[0, 0].set_xlabel('Time')
ax[0, 0].set_title('Functions at x=0')
ax[0, 0].legend()

_plot_spectrum(ax[1, 0], fvals, fspec1, 'Y1', f)
_plot_spectrum(ax[2, 0], fvals, fspec2, 'Y2', f)
# At x=0, y2(t) = cos(omega*t - phi), so the expected temporal cross-phase at f is -phi
_plot_phases(ax[3, 0], fvals, fphase1, fphase2, fphase12, f, -phi / np.pi, '$-\\phi$', 'f (Hz)')

# Spatial profiles at t = t[0]
ax[0, 1].plot(x, y1[:, 0], label='y1(x)')
ax[0, 1].plot(x, y2[:, 0], label='y2(x)')
ax[0, 1].set_xlabel('x')
ax[0, 1].set_title('Functions at t=0')
ax[0, 1].legend()

_plot_spectrum(ax[1, 1], kvals, kspec1, 'Y1', k)
_plot_spectrum(ax[2, 1], kvals, kspec2, 'Y2', k)
# At t=0, y2(x) = cos(k*x + phi), so the expected spatial cross-phase at k is +phi
_plot_phases(ax[3, 1], kvals, kphase1, kphase2, kphase12, k, phi / np.pi, '$\\phi$', 'k (rad/m)')

# Interactive figure: both spatial profiles, stepped through time with a slider
fig, ax = plt.subplots()
ax.set_title('Traveling Wave')
# Shrink the plotting area so the slider fits underneath it
plt.subplots_adjust(left=0.15, bottom=0.25)
# Two empty line artists on the current axes, then filled with the t = t[0] profiles
lines = [plt.plot([], [])[0] for _ in range(2)]
for line, profile in zip(lines, (y1[:, 0], y2[:, 0])):
    line.set_data(x, profile)
# Fix the view to the full x range and the value range of y1
plt.axis([np.amin(x), np.amax(x), np.amin(y1), np.amax(y1)])
for line, color in zip(lines, ('b', 'g')):
    line.set_color(color)
# Slider spanning the whole time record, starting at its first sample
axstime = plt.axes([0.15, 0.1, 0.70, 0.03])
t_start, t_stop = np.amin(t), np.amax(t)
stime = Slider(axstime, 'Time', t_start, t_stop, valinit=t_start, valfmt='%1.3f')

def update(tval):
    """Redraw both spatial profiles at the time sample nearest the slider value.

    Registered as the Slider callback below; *tval* is the slider's new value,
    in the same units as t.
    """
    # Nearest sample is always a valid index. searchsorted rounds up to the next
    # sample and returns len(t) for a value beyond the last one.
    idx = np.argmin(np.abs(t - tval))
    lines[0].set_data(x, y1[:, idx])
    lines[1].set_data(x, y2[:, idx])
    # Schedule a redraw for the next GUI idle point rather than rendering synchronously
    fig.canvas.draw_idle()


stime.on_changed(update)
plt.show()