from pseudospectral import *
from pylab import *
from matplotlib.animation import FuncAnimation
import threading
from scipy.integrate import solve_ivp

### Simulation parameters
N = 256     # grid size: the fields live on an N x N grid
nu = 6e-6   # viscosity coefficient; a value that is too small makes
            # the simulation unstable
Ro = 0.1    # passed to Operators2D; presumably a Rossby number or
            # deformation radius -- TODO confirm against pseudospectral

# Spectral operators for this grid, built by the project-local
# `pseudospectral` module (names suggest Laplacian, inverse Helmholtz
# operator, filter and Jacobian).
Lap, Helm_inv, Filt, J = Operators2D(N, Ro)

# Lock guarding the global `frame_s`, which is written by the solver
# thread (inside qg) and read by the animation (inside animate).  Because
# the writer only rebinds a reference and the reader only reads it, the
# program would very likely work without the lock under CPython; the lock
# is kept for formal correctness rather than relying on that detail.
lock = threading.Lock()
    
def qg(t, omega_s):
    """Right-hand side of the evolution equation, in the form solve_ivp expects.

    Besides returning the time derivative, this function publishes the
    state it was called with through the global ``frame_s`` so that the
    animation can display it.  solve_ivp offers no per-step callback, so
    the RHS doubles as one.  Note that the solver may also evaluate the
    RHS at intermediate (trial) states, so a published frame is not
    necessarily an accepted step.

    Parameters
    ----------
    t : float
        Current time; not used by the right-hand side.
    omega_s : ndarray
        Flattened spectral field (presumably vorticity) with N*N entries.

    Returns
    -------
    ndarray
        Flattened derivative ``-J(psi_s, omega) + nu * Lap(omega)``.
    """
    global frame_s  # global variable to communicate across threads
    with lock:      # publish the latest state for the animation thread
        frame_s = omega_s
    # The operators work on the 2-D (N, N) layout; reshape once and reuse
    # the view instead of reshaping for every operator call.
    omega = omega_s.reshape(N, N)
    psi_s = Helm_inv(omega)  # presumably the streamfunction -- confirm
    return (-J(psi_s, omega) + nu * Lap(omega)).ravel()

# Integration function to run in a separate thread
def run_solver(t_end=10000):
    """Integrate the model from t = 0 to ``t_end``; meant to run in a thread.

    Progress reaches the animation only through the side effect in ``qg``
    (it publishes each evaluated state). The object returned by
    ``solve_ivp`` is discarded.

    Parameters
    ----------
    t_end : float, optional
        Final integration time.  Defaults to 10000, the value that was
        previously hard-coded.
    """
    # With the default t_eval=None, solve_ivp keeps a copy of the full
    # state at every accepted step.  Here that is N*N values per step for
    # the whole run, even though the result is thrown away.  Requesting only
    # the final time keeps memory bounded.
    solve_ivp(
        qg,
        [0, t_end],
        omega_s.ravel(),
        t_eval=[t_end],
    )

def init():
    """Animation init hook for FuncAnimation.

    Nothing needs resetting; it just hands back the artist that the
    animation redraws.
    """
    artists = [image]
    return artists

def animate(i):
    """Redraw the image from the most recently published spectral frame.

    Parameters
    ----------
    i : int
        Frame index supplied by FuncAnimation; not used.

    Returns
    -------
    list
        The artists updated for this frame.
    """
    # Only grabbing the reference needs the lock.  qg rebinds frame_s
    # rather than writing into it, so the snapshot stays usable after the
    # lock is released.  That assumes the solver does not modify the array
    # it passed to qg in place; NOTE(review): confirm for other methods.
    # The inverse FFT is done outside the critical section so the solver
    # thread is not blocked while the frame is converted.
    with lock:
        snapshot = frame_s
    frame = ifft2(snapshot.reshape(N, N)).real
    image.set_array(frame)
    # Symmetric color limits keep zero at the center of the 'RdBu' colormap.
    max_abs = abs(frame).max()
    image.set_clim(-max_abs, max_abs)
    return [image]

if __name__ == "__main__":
    # Start with a white-noise-like random field
    omega_s = Filt(fft2(rand(N,N)))
    frame_s = omega_s.copy()

    # Set up the figure and axis for animation
    fig, ax = plt.subplots()
    image = ax.imshow(ifft2(frame_s).real,
                      cmap='RdBu',
                      vmin=-1, vmax=1)

    # Run the integration in a separate daemon thread, so it does not
    # keep the interpreter alive once the main thread is done.
    solver_thread = threading.Thread(target=run_solver, daemon=True)
    solver_thread.start()

    # Keep a reference to the FuncAnimation object; without one it can be
    # garbage-collected and the animation stops updating.
    animation = FuncAnimation(fig, animate,
                              interval=50,
                              init_func=init)
    show()

    # Intentionally no solver_thread.join().  run_solver discards its
    # result, so waiting after the window is closed would only block exit
    # until the integration reaches its end time (t = 10000).  The daemon
    # thread is stopped when the main thread exits.

# The initial idea of running the solver in a thread parallel to the
# animation came from ChatGPT:
# https://chatgpt.com/share/678ec1b1-4fc0-8002-95ad-0737e23f9dd8
# However, ChatGPT did not realize that solve_ivp has no per-step
# callback (its `events` argument only locates zero crossings of event
# functions).  In our case, we can simplify the code anyway by putting
# the callback into the right-hand-side function.
