# -*- coding: utf-8 -*-
"""
Created on Sat May  2 22:40:32 2026

@author: Soyuz Kerman
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

#%% Single plot

def calc_refraction(i, n_i, n_r):
    """Return the refraction angle given by the Snell-Descartes law.

    Solves ``n_i*sin(i) = n_r*sin(r)`` for ``r``.

    Parameters
    ----------
    i : float or array_like
        Angle of incidence, in degrees.
    n_i : float or array_like
        Index of refraction on the incidence side.
    n_r : float or array_like
        Index of refraction on the refraction side.

    Returns
    -------
    r_rad : float or ndarray
        Refraction angle in radians. It is ``nan`` when
        ``n_i/n_r * sin(i)`` falls outside [-1, 1], i.e. when there is no
        refracted ray (total internal reflection).
    """
    i_rad = np.deg2rad(i)
    sin_r = n_i/n_r * np.sin(i_rad)
    # np.arcsin exists in every NumPy release; the np.asin alias was only
    # added in NumPy 2.0 and raises AttributeError on older installs.
    # The "invalid value" RuntimeWarning of the total internal reflection
    # case is expected, so it is silenced here and nan is returned as is.
    with np.errstate(invalid="ignore"):
        r_rad = np.arcsin(sin_r)
    return r_rad


def calc_rays_coords(i, n_i, n_r):
    """Compute the end points of the incoming, refracted and reflected rays.

    All three rays have unit length and meet at the origin. The incoming
    and reflected rays lie in the upper half plane (y > 0), the refracted
    ray in the lower one.

    Parameters
    ----------
    i : angle of incidence, in degrees
    n_i, n_r : indices of refraction, passed on to calc_refraction

    Returns
    -------
    (incoming, refracted, reflected, r_rad), where each ray is laid out as
    [[x1, x2], [y1, y2]] and r_rad is the refraction angle in radians as
    returned by calc_refraction.
    """
    r_rad = calc_refraction(i, n_i, n_r)
    i_rad = np.deg2rad(i)

    sin_i, cos_i = np.sin(i_rad), np.cos(i_rad)
    sin_r, cos_r = np.sin(r_rad), np.cos(r_rad)

    incoming = [[-sin_i, 0], [cos_i, 0]]    # upper left -> origin
    refracted = [[0, sin_r], [0, -cos_r]]   # origin -> lower right
    reflected = [[sin_i, 0], [cos_i, 0]]    # upper right -> origin

    return incoming, refracted, reflected, r_rad


def plot_refraction(fig, ax, i, n_i, n_r):
    """Draw the three rays and the two media on ``ax``.

    Parameters
    ----------
    fig : figure holding ``ax`` (not used by this function)
    ax : axes to draw on
    i : angle of incidence, in degrees
    n_i, n_r : indices of refraction, passed on to calc_rays_coords

    Returns
    -------
    (l_i, l_r, l_rx, r_rad) : the line objects of the incoming, refracted
    and reflected rays, so that an interactive version can update them,
    followed by the refraction angle in radians.
    """
    incoming, refracted, reflected, r_rad = calc_rays_coords(i, n_i, n_r)

    # Square view centred on the point where the rays meet
    ax.set_xlim(-.5, .5)
    ax.set_ylim(-.5, .5)
    ax.set_aspect("equal")

    l_i, = ax.plot(*incoming, 'r', label="Incoming ray")
    l_r, = ax.plot(*refracted, 'r:', label="Refracted ray")
    l_rx, = ax.plot(*reflected, 'r--', label="Reflected ray")

    # Boundary between the two media
    ax.plot([-.5, .5], [0, 0], "k")
    # (call ax.legend() here to display the ray labels)

    # Coloured background for the second medium (lower half)
    second_medium = Rectangle((-0.5, -1), 1, 1,
                              facecolor="lavender", alpha=1)
    ax.add_patch(second_medium)

    return l_i, l_r, l_rx, r_rad


# Example case: light going from air into glass
i = 30      # angle of incidence, in degrees
n_i = 1     # index of refraction of air
n_r = 1.5   # index of refraction of glass

# One figure with a single axes filling it
fig = plt.figure("Refraction simulator")
ax = fig.add_subplot(1, 1, 1)

plot_refraction(fig, ax, i=i, n_i=n_i, n_r=n_r)

#%% Interactive plot

# tutorial : https://www.youtube.com/watch?v=p-xJsc6LSx0
# https://www.geeksforgeeks.org/python/matplotlib-slider-widget/

# DO NOT USE INLINE PLOTS: the sliders, text boxes and button below only
# respond in an interactive figure window.
# In Spyder: Tools > Preferences > IPython Console > Plotting > Graphics Backend > Automatic


from matplotlib.widgets import Slider
from matplotlib.widgets import Button
from matplotlib.widgets import TextBox

# clear=True wipes whatever is already drawn on the figure carrying this
# label. The single-plot cell above uses the same label, and so does any
# previous run of this cell; without clearing, the new axes and widgets
# would be stacked on top of the old ones.
fig, ax = plt.subplots(num="Refraction simulator", clear=True)
fig.subplots_adjust(bottom=0.25)    # free the bottom strip for the sliders
ax.set_title("Refraction simulator")

# Initial values
i = 30      # angle of incidence, in degrees
n_i = 1     # index of refraction, incidence side
n_r = 1.5   # index of refraction, refraction side

# Default slider ranges
n_i_min = 1
n_i_max = 2
n_r_min = 1
n_r_max = 4

# Slider position (left edge and base height, in figure coordinates)
slider_x = .375
slider_y = .05

# Text box position (left edge and base height) and box height
box_x = .15
box_y = .375
box_height = .05

# Keep the line objects: the slider callback updates them in place
l_i, l_r, l_rx, r_rad = plot_refraction(fig, ax, i, n_i, n_r)


# Sliders for the angle of incidence and the two indices of refraction,
# stacked from top (angle) to bottom (n_r)

ax_slider_angle = plt.axes([slider_x, slider_y + .1, .3, .03])
slider_angle = Slider(ax=ax_slider_angle, label="Angle $i$",
                      valmin=0, valmax=90, valinit=i)

ax_slider_n_i = plt.axes([slider_x, slider_y + .05, .3, .03])
slider_n_i = Slider(ax=ax_slider_n_i, label="$n_i$",
                    valmin=n_i_min, valmax=n_i_max, valinit=n_i)

ax_slider_n_r = plt.axes([slider_x, slider_y, .3, .03])
slider_n_r = Slider(ax=ax_slider_n_r, label="$n_r$",
                    valmin=n_r_min, valmax=n_r_max, valinit=n_r)


# Text boxes for the slider ranges: one box per bound, stacked from top
# (n_i max) to bottom (n_r min)

ax_slider_n_i_max = plt.axes([box_x, box_y + .3, .05, box_height])
box_n_i_max = TextBox(ax_slider_n_i_max, label="$n_i$ max  ", initial=str(n_i_max))

ax_slider_n_i_min = plt.axes([box_x, box_y + .2, .05, box_height])
box_n_i_min = TextBox(ax_slider_n_i_min, label="$n_i$ min  ", initial=str(n_i_min))

ax_slider_n_r_max = plt.axes([box_x, box_y + .1, .05, box_height])
box_n_r_max = TextBox(ax_slider_n_r_max, label="$n_r$ max  ", initial=str(n_r_max))

ax_slider_n_r_min = plt.axes([box_x, box_y, .05, box_height])
# This box shows the lower bound of n_r, so it starts from n_r_min
# (it used to be initialised from n_i_min).
box_n_r_min = TextBox(ax_slider_n_r_min, label="$n_r$ min  ", initial=str(n_r_min))


# Reset button, placed to the right of the plot

resetax = plt.axes([0.8, .5, 0.1, 0.1])
button = Button(ax=resetax, label='Reset',
                color='gold', hovercolor='skyblue')


# Annotations

def gen_annot_text(i, r_rad, n_i, n_r):
    """Build the text of the annotation shown on the plot.

    Parameters
    ----------
    i : angle of incidence, in degrees
    r_rad : refraction angle, in radians (converted to degrees for display)
    n_i, n_r : indices of refraction

    Returns
    -------
    str : two lines for the incidence side (i, n_i), two empty lines,
    then two lines for the refraction side (r, n_r). Every value is
    rounded to two decimals.
    """
    shown = (i, n_i, n_r, np.rad2deg(r_rad))
    i, n_i, n_r, r = (np.round(value, 2) for value in shown)
    return f"$i = {i}°$\n$n_i = {n_i}$\n\n\n$r = {r}°$\n$n_r = {n_r}$"

# Read-out of the current values, centred vertically on the boundary line
annot = ax.annotate(
    gen_annot_text(i, r_rad, n_i, n_r),
    xy=(-.4, 0),
    ha='left',
    va='center',
)


# Slider & annotation update function

def update_slider(val):
    """Refresh the rays and the annotation from the current slider values.

    Registered as the callback of all three sliders. ``val`` (the new value
    of the slider that moved) is ignored: the three sliders are read
    directly instead.
    """
    angle = slider_angle.val
    index_i = slider_n_i.val
    index_r = slider_n_r.val

    incoming, refracted, reflected, refr_angle = calc_rays_coords(
        angle, index_i, index_r)

    # Move the existing lines rather than plotting new ones
    for line, coords in ((l_i, incoming), (l_r, refracted), (l_rx, reflected)):
        line.set_data(coords)
    annot.set_text(gen_annot_text(angle, refr_angle, index_i, index_r))


# Any slider change triggers the same refresh
slider_angle.on_changed(update_slider)
slider_n_i.on_changed(update_slider)
slider_n_r.on_changed(update_slider)


# Reset button function

def reset_sliders(event):
    """Reset the three sliders; callback of the Reset button.

    ``event`` is supplied by the button and is not used.
    """
    for slider in (slider_angle, slider_n_i, slider_n_r):
        slider.reset()


button.on_clicked(reset_sliders)


# Slider range update function

def _apply_slider_range(slider, box_min, box_max):
    """Give ``slider`` the bounds typed in ``box_min`` / ``box_max``.

    Returns True when the range was applied, False when the entries were
    rejected (in which case the slider is left untouched).
    """
    try:
        new_min = float(box_min.text)
        new_max = float(box_max.text)
    except ValueError:
        # Empty or non-numeric entry: keep the current range instead of
        # raising from inside the GUI callback.
        return False

    # set_xlim refuses NaN/Inf, and an empty or inverted interval makes
    # no sense for a slider.
    if not (np.isfinite(new_min) and np.isfinite(new_max) and new_min < new_max):
        return False

    slider.valmin = new_min
    slider.valmax = new_max
    slider.ax.set_xlim(new_min, new_max)
    return True


def update_text(text):
    """Update the range of the n_i and n_r sliders from the text boxes.

    Registered as the submit callback of the four range boxes. ``text``
    (the content of the box that was submitted) is ignored: all four boxes
    are read again. The two sliders are handled independently, so an
    invalid entry for one does not prevent the other from being updated.
    """
    applied = [
        _apply_slider_range(slider_n_i, box_n_i_min, box_n_i_max),
        _apply_slider_range(slider_n_r, box_n_r_min, box_n_r_max),
    ]
    if any(applied):
        # Explicit redraw request so the new limits are displayed
        fig.canvas.draw_idle()


box_n_i_max.on_submit(update_text)
box_n_i_min.on_submit(update_text)
box_n_r_max.on_submit(update_text)
box_n_r_min.on_submit(update_text)



