In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from IPython.display import HTML
import matplotlib.animation as anim
import matplotlib;
# matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
import plotly.io as pio

pio.renderers.default = "browser"

import plotly.graph_objects as go
import hj_reachability as hj
import time
# import netCDF4
from functools import partial

import seaborn as sns
import pandas as pd


import jax.scipy as jsc
import functools
from jax.tree_util import register_pytree_node_class


jcm = jax.default_device(jax.devices()[2])
jcm.__enter__()

In [2]:
#Boundary
def dirichlet(x, pad_width: int, fill_value: float = 3.0):
    """Boundary condition that pads ``x`` with a constant value.

    Passed to ``hj.Grid.from_lattice_parameters_and_boundary_conditions`` as the
    boundary condition for both axes. The ghost cells added on each side are
    filled with ``fill_value``, which defaults to 3.0, the value previously
    hard-coded here.

    Args:
        x: Array to pad.
        pad_width: Number of ghost cells added before and after each axis.
        fill_value: Constant written into the ghost cells.

    Returns:
        ``x`` padded by ``pad_width`` on both sides of every axis.
    """
    # A single (before, after) pair is broadcast by jnp.pad to every axis of x.
    return jnp.pad(x, (pad_width, pad_width), "constant", constant_values=fill_value)


def dg_x(params, arg):
    """Evaluate ``_dg_x`` for a packed state ``arg = (x, y, t, ...)``.

    Adapter between the packed-argument form used by the current callables and
    ``_dg_x``'s ``(t, x, y, params)`` argument order.
    """
    return _dg_x(arg[2], arg[0], arg[1], params)

def dg_y(params, arg):
    """Evaluate ``_dg_y`` for a packed state ``arg = (x, y, t, ...)``.

    Adapter between the packed-argument form used by the current callables and
    ``_dg_y``'s ``(t, x, y, params)`` argument order.
    """
    return _dg_y(arg[2], arg[0], arg[1], params)


def _dg_x(t,x,y,params):
    eps = params[0]
    A = params[1]
    omega = params[2]
    a = eps*jnp.sin(omega*t)
    b = 1-2*a
    f = a*(x**2) + b*x
    return -jnp.pi*A*jnp.sin(jnp.pi*f)*jnp.cos(jnp.pi*y)


def _dg_y(t,x,y,params):
    eps, A, omega = (*params,)
    a = eps*jnp.sin(omega*t)
    b = 1-2*a
    f = a*(x**2) + b*x
    df = 2*(a*x) + b
    return jnp.pi*A*jnp.cos(jnp.pi*f)*jnp.sin(jnp.pi*y)*df


# Batched versions of the two current components: t is shared (in_axes None),
# while x, y and params are all mapped along their leading axis.
# NOTE(review): mapping params over axis 0 means each batch element needs its own
# (eps, A, omega) row. If one shared parameter vector was intended, in_axes should
# probably be (None, 0, 0, None). These are not referenced in the cells shown here.
v_dg_x = jax.vmap(_dg_x,(None,0,0,0),0)
v_dg_y = jax.vmap(_dg_y,(None,0,0,0),0)

In [3]:
# #%% Run reachability on the current augmented problem
# # init settings
# target settings

# State grid: x in [0, 2] with 101 nodes and y in [0, 1] with 51 nodes, using the
# constant-fill `dirichlet` boundary condition on both axes.
grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(hj.sets.Box(lo=np.array([0., 0.]),
                                                                        hi=np.array([2., 1.])), (101, 51),
                                                              boundary_conditions=(dirichlet, dirichlet))
# Solver time points: 91 values decreasing from 0 to -2.8 (solved backward from t=0).
times = jnp.linspace(0, -2.8, 91)
# params = [1.,1.,0.1]

def get_val_func(x_init,x_target, params, grid=grid, times=times):
    """Solve the HJ reachability problem for double-gyre currents built from ``params``.

    Args:
        x_init: Unused; kept so existing call sites keep working.
        x_target: Centre of the target set, a sphere of radius 0.02 on ``grid``.
        params: ``(eps, A, omega)`` forwarded to ``dg_x`` / ``dg_y``.
        grid: Computational grid (defaults to the module-level ``grid``).
        times: Solver time points (defaults to the module-level ``times``).

    Returns:
        The result of ``hj.solve``: the value function at every entry of ``times``
        (indexed elsewhere in the notebook as ``[time, x, y]``).
    """
    initial_values = hj.shapes.shape_sphere(grid=grid, center=x_target, radius=0.02)

    def multi_reach_step(mask, time, val):
        # Hamiltonian post-processor: force the value to -1 wherever the target
        # mask is negative. The time argument is ignored.
        return jnp.where(mask < 0, -1, val)

    p_multi_reach_step = partial(multi_reach_step, initial_values)

    solver_settings = hj.SolverSettings.with_accuracy(
        "high",
        artificial_dissipation_scheme=hj.artificial_dissipation.local_local_lax_friedrichs,
        hamiltonian_postprocessor=p_multi_reach_step,
    )

    # NOTE(review): y_current is fed dg_x (-d(psi)/dy) and x_current is fed dg_y
    # (d(psi)/dx). For a stream function the x-velocity is usually -d(psi)/dy, so
    # these look swapped. Confirm against Platform2Dcurrents' convention; they are
    # kept as-is to stay consistent with the other cells that build currents.
    def y_cur(arg):
        return dg_x(params, arg)

    def x_cur(arg):
        return dg_y(params, arg)

    Plat2D_sim = hj.systems.Platform2Dcurrents(u_max=1.0, control_mode='min', x_current=x_cur, y_current=y_cur)

    return hj.solve(solver_settings, Plat2D_sim, grid, times, initial_values, progress_bar=False)

In [4]:
from jax import jacfwd, jacrev, jit
_key = jax.random.PRNGKey(1)
key, _key = jax.random.split(_key)

x_init = jnp.array([0.25, 0.6])
x_target = jnp.array([1.6, 0.5])
r=1e-2

def termination_condn(x_target, r, x, t):
    """Return True once ``x`` lies within Euclidean distance ``r`` of ``x_target``.

    ``t`` is accepted so the function fits the ``termination_condn`` callback
    passed to ``backtrack_trajectory``, but it is ignored.
    """
    distance_to_target = jnp.linalg.norm(x_target - x)
    return distance_to_target <= r

termination_condn = partial(termination_condn, x_target, r)

get_val_func_params = lambda params: get_val_func(x_init, x_target, params)

In [35]:
true_params = jnp.array([0.1,0.4,0.1])

init_params = jnp.array([0.01,0.01,0.01])

def y_cur(arg):
    """``y_current`` callable for the ground-truth platform: ``dg_x`` at ``true_params``."""
    # NOTE(review): dg_x evaluates -d(psi)/dy, which is usually the x-velocity;
    # confirm the x/y pairing matches Platform2Dcurrents' convention.
    current = dg_x(true_params, arg)
    return current

def x_cur(arg):
    """``x_current`` callable for the ground-truth platform: ``dg_y`` at ``true_params``."""
    current = dg_y(true_params, arg)
    return current

# run the solver
Plat2D_true = hj.systems.Platform2Dcurrents(u_max=1.0, control_mode='min', x_current=x_cur, y_current=y_cur)


In [36]:
def traj_solve(params, times=times, val_func=None, x_init=x_init, grid=grid):
    """Roll the closed-loop trajectory under currents built from ``params``.

    The state starts at ``x_init`` at ``times[-1]`` and is stepped with explicit
    Euler through ``times[-2], ..., times[1]``, using the optimal control read from
    ``val_func`` at each step. ``times[0]`` is never reached. The loop uses
    functional ``.at[].set`` updates so ``jax.jacrev`` can differentiate the result
    with respect to ``params``.

    Args:
        params: ``(eps, A, omega)`` forwarded to ``dg_x`` / ``dg_y``.
        times: Solver time points aligned with the leading axis of ``val_func``.
        val_func: Value-function array aligned with ``times``. Required; it used to
            default to a global that does not exist yet when this cell first runs.
        x_init: Initial state.
        grid: Grid on which ``val_func`` is defined.

    Returns:
        ``x_traj[::3].ravel()`` where ``x_traj`` has shape (grid.ndim, len(times) - 1).

    Raises:
        ValueError: If ``val_func`` is not given.
    """
    if val_func is None:
        raise ValueError("traj_solve requires an explicit val_func")

    def y_cur(arg):
        return dg_x(params, arg)

    def x_cur(arg):
        return dg_y(params, arg)

    Plat2D = hj.systems.Platform2Dcurrents(u_max=1.0, control_mode='min', x_current=x_cur, y_current=y_cur)

    x_traj = jnp.zeros((grid.ndim, len(times) - 1))
    x_final = x_init
    x_traj = x_traj.at[:, 0].set(x_final)
    # `t` (not `time`) so the stdlib `time` module is not shadowed.
    for i, t in enumerate(times[:1:-1]):
        u_opt, d_opt = Plat2D.get_opt_ctrl_from_values(grid, x_final, t, times, val_func)
        # times[-i-1] == t. With decreasing times dt < 0, so subtracting dt * f
        # advances the state forward in physical time.
        dt = times[-i - 1] - times[-i - 2]
        x_final = x_final - dt * jnp.array(Plat2D(x_final, u_opt, d_opt, t))
        x_traj = x_traj.at[:, i + 1].set(x_final)
    # NOTE(review): [::3] slices the *state* axis (rows), so only row 0 (the first
    # coordinate) is returned. If time subsampling was intended this should be
    # x_traj[:, ::3]. Kept as-is because downstream Jacobian/residual shapes depend on it.
    return x_traj[::3].ravel()

In [22]:
def traj_solve_debug(params, times=times, val_func=None, x_init=x_init, grid=grid):
    """Reference trajectory from ``Platform2Dcurrents.backtrack_trajectory``.

    Builds the platform from ``params`` and backtracks from ``x_init`` using the
    module-level ``termination_condn``.

    Args:
        params: ``(eps, A, omega)`` forwarded to ``dg_x`` / ``dg_y``.
        times: Solver time points aligned with ``val_func``.
        val_func: Value-function array aligned with ``times``. Required; it used to
            default to a global that does not exist yet when this cell first runs.
        x_init: Initial state.
        grid: Grid on which ``val_func`` is defined.

    Returns:
        ``x_traj[:, 0]``, the first column of the backtracked state trajectory.

    Raises:
        ValueError: If ``val_func`` is not given.
    """
    if val_func is None:
        raise ValueError("traj_solve_debug requires an explicit val_func")

    def y_cur(arg):
        return dg_x(params, arg)

    def x_cur(arg):
        return dg_y(params, arg)

    Plat2D = hj.systems.Platform2Dcurrents(u_max=1.0, control_mode='min', x_current=x_cur, y_current=y_cur)
    # Only the state trajectory is needed; times/controls/disturbances are discarded.
    _, x_traj, _, _ = Plat2D.backtrack_trajectory(grid, x_init,
                                                  times, val_func,
                                                  termination_condn=termination_condn)
    return x_traj[:, 0]

In [37]:
val_func = get_val_func_params(init_params)

In [38]:
pos_sim = traj_solve_debug(init_params, times[75:], val_func[75:,:,:])


In [39]:
pos_sim

array([0.71655101, 0.56864899])

In [40]:
# Reverse-mode Jacobian of the flattened trajectory returned by traj_solve with
# respect to its first positional argument (params), jacrev's default argnums=0.
jac_pos = jax.jacrev(traj_solve)

In [41]:
pos_sim = traj_solve(init_params, times[75:], val_func[75:,:,:])
pos_true = traj_solve(true_params, times[75:], val_func[75:,:,:])

print("sim pos", pos_sim, "\n true pos", pos_true,"\n")


dpos = (pos_sim-pos_true)
print("\ndpos\n",dpos)

sim pos [0.25       0.28167006 0.31327528 0.34480965 0.37626725 0.40764388
 0.43893528 0.47013906 0.50125265 0.5322755  0.5632069  0.5940481
 0.6248001  0.65546584 0.6860481 ] 
 true pos [0.25       0.30733457 0.35871038 0.40421185 0.44426453 0.47947142
 0.5104927  0.5379776  0.56252766 0.58468497 0.60492814 0.6236783
 0.64130324 0.6581243  0.67441946] 


dpos
 [ 0.         -0.02566451 -0.0454351  -0.0594022  -0.06799728 -0.07182753
 -0.0715574  -0.06783852 -0.06127501 -0.05240947 -0.04172122 -0.02963024
 -0.01650316 -0.00265849  0.01162863]


In [42]:
jac_pos_sim = jac_pos(init_params, times[75:], val_func[75:,:,:])
# DeviceArray.to_py() has been removed from JAX; np.asarray converts the
# Jacobian to a NumPy array for the NumPy linear algebra below.
jac_pos_sim = np.asarray(jac_pos_sim)

In [64]:
# One damped Gauss-Newton (Levenberg-Marquardt) step on the residual
# dpos = pos_sim - pos_true:  p_new = p - (J^T J + lam*I)^{-1} J^T dpos.
# np.linalg.solve avoids forming the explicit inverse, and the identity size is
# taken from J instead of being hard-coded to the number of parameters.
lm_damping = 1e-3
lm_lhs = jac_pos_sim.T.dot(jac_pos_sim) + lm_damping * np.eye(jac_pos_sim.shape[1])
init_params - np.linalg.solve(lm_lhs, jac_pos_sim.T.dot(dpos))

DeviceArray([0.0694989 , 0.19613367, 0.06948611], dtype=float32)

In [45]:
res = np.linalg.lstsq(jac_pos_sim, dpos, rcond=None)

In [60]:
res[0]

array([-6.9366095e+06, -3.8277200e-01,  6.9379715e+06], dtype=float32)

In [46]:
params = init_params - res[0]
print(params)

[ 6.9366095e+06  3.9277199e-01 -6.9379715e+06]


In [17]:
true_params

DeviceArray([0.1, 0.4, 0.1], dtype=float32)

In [None]:
np.linalg.lstsq?

In [None]:
jax.numpy.linalg.lstsq?

In [None]:
traj_times

In [None]:
plt.pcolor(grid.states[...,0], grid.states[...,1], val_func[-1,:,:], vmin=-4, vmax=4, cmap=matplotlib.colormaps['RdBu']); plt.colorbar()

In [None]:
traj_times, x_traj, contr_seq, distr_seq = Plat2D_true.backtrack_trajectory(grid, x_init, 
                                                                            times, val_func, 
                                                                            traj_times=np.linspace(-2.5,-3,50), 
                                                                            termination_condn=termination_condn)

In [None]:
plt.plot(x_traj[0,:],x_traj[1,:])
plt.xlim([0,2])
plt.ylim([0,1])

In [None]:
plt.plot(x_traj[0,:],x_traj[1,:])
plt.xlim([0,2])
plt.ylim([0,1])

In [None]:
traj_times

In [None]:
x_traj.shape

In [None]:
def simulate(params):
    """Placeholder for a forward simulation driven by ``params``.

    TODO: not implemented. The cell previously contained a ``def`` with no body,
    a SyntaxError that aborted Restart & Run All; this stub keeps the notebook
    executable and fails loudly if called.
    """
    raise NotImplementedError("simulate() has not been implemented yet")

In [None]:
all_vals.shape

In [None]:
@register_pytree_node_class
class GMM(object):
    """Gaussian mixture model registered as a JAX pytree.

    The pytree leaves are ``weights`` (M,), ``means`` (M, d) and ``covs`` (M, d, d),
    so a ``GMM`` can be passed through ``jax.jit`` and carried by ``jax.lax.scan``.
    """

    def __init__(self,Π,μ,Σ):
        # Normalise the mixing weights so they sum to one.
        self.weights = Π/jnp.sum(Π)
        self.means = μ
        self.covs = Σ

    @property
    def n(self):
        """Number of mixture components (length of ``weights``)."""
        return self.weights.shape[0]

    def __repr__(self):
        return "\n GMM with M={} \n Weights: \n{} \n Means: \n{} \n Covariances: \n{}\n".format(
            self.n, self.weights, self.means, self.covs)

    def tree_flatten(self):
        """Pytree flattening: all three parameter arrays are leaves; no aux data."""
        return ([self.weights, self.means, self.covs],None)

    @classmethod
    def tree_unflatten(cls,aux_data,params):
        """Rebuild a GMM from leaves without calling ``__init__``.

        JAX may unflatten with tracers or placeholder objects, and the stored
        weights are already normalised, so ``__init__``'s normalisation must not
        run again here.
        """
        obj = object.__new__(cls)
        obj.weights, obj.means, obj.covs = params
        return obj

    def get_samples(self, key, n_mc: int = 1):
        """Draw ``n_mc`` samples from the mixture; returns shape (n_mc, d)."""
        cat_key, multi_key = jax.random.split(key)
        # Choose a component for every sample, as a one-hot mask of shape (n_mc, M, 1).
        latent_variables = jax.random.categorical(cat_key, jnp.log(self.weights), shape=[n_mc])
        latent_one_hot = jnp.expand_dims(jax.nn.one_hot(latent_variables, self.n), -1)
        # Draw one candidate per component (n_mc, M, d), then keep the selected one.
        all_samples = jax.random.multivariate_normal(multi_key, self.means, self.covs, shape=[n_mc,self.n])
        return jnp.sum(latent_one_hot*all_samples,axis=1)

    def get_pdf(self,x,get_components=False):
        """Mixture density at points ``x`` (N, d).

        Returns the density (N,), or ``(density, components)`` when
        ``get_components`` is True, where ``components`` is a list of the M
        weighted component densities.
        """
        components = [
            π * jsc.stats.multivariate_normal.pdf(x, μ, σ)
            for π, μ, σ in zip(self.weights, self.means, self.covs)
        ]
        likelihood = jnp.sum(jnp.array(components), axis=0)
        if get_components:
            return likelihood, components
        return likelihood

    def get_membership_weights(self,x):
        """Responsibilities of each component for each point, shape (M, N)."""
        likelihood, components = self.get_pdf(x,get_components=True)
        return jnp.array(components)/likelihood

    def _em_step(self, x):
        """One EM iteration on ``x`` (N, d); returns (weights, means, covs) without mutating ``self``."""
        resp = self.get_membership_weights(x)          # (M, N)
        n_k = jnp.sum(resp, axis=-1)                   # (M,) effective counts
        weights = n_k/x.shape[0]
        means = ((resp.dot(x).T)/n_k).T                # (M, d)
        # The M-step covariance is centred on the *updated* means; centring on
        # self.means (the previous iterate) would bias the estimate.
        x_m_mu = jnp.expand_dims(x,1) - jnp.expand_dims(means,axis=0)   # (N, M, d)
        covs = jnp.einsum('ki,ikd,ikj->kdj',resp,x_m_mu,x_m_mu)/jnp.expand_dims(n_k,[-1,-2])
        return weights, means, covs

    def em_step(self, x):
        """Apply one EM iteration in place and return ``self``."""
        self.weights, self.means, self.covs = self._em_step(x)
        return self

    @functools.partial(jax.jit, static_argnames=("steps",))
    def fit_gmm(self,x, steps=100):
        """Run ``steps`` EM iterations on ``x`` and return the fitted GMM.

        ``steps`` is static because ``jax.lax.scan`` needs a concrete length;
        without ``static_argnames`` an explicitly passed ``steps`` would be traced
        and the scan would fail.
        """
        def sub_step(gmm, _):
            updated = type(gmm).tree_unflatten(None, gmm._em_step(x))
            # No per-step history is needed, so emit nothing as the scan output.
            return updated, None

        fitted, _ = jax.lax.scan(sub_step, self, xs=None, length=steps)
        return fitted

In [None]:
%matplotlib inline

In [None]:
from jax import jacfwd, jacrev, jit
_key = jax.random.PRNGKey(0)
key, _key = jax.random.split(_key)

In [None]:
# FIXME: get_val_func's signature is (x_init, x_target, params, grid, times).
# This call passes vel_speed[0] as `params` and vel_speed[1] positionally as
# `grid`, so it will not work as intended. It looks like a leftover from an older
# (velocity, speed) signature. Map each vel_speed sample onto an (eps, A, omega)
# params vector before vmapping.
get_val_func_vel_speed = lambda vel_speed: get_val_func(x_init, x_target, vel_speed[0], vel_speed[1])

v_val = jax.vmap(get_val_func_vel_speed,0,0)

In [None]:
n_mc = 500
vel_speed = jax.random.multivariate_normal(key, jnp.array([0.0,1.0]), jnp.array([[2.0,0],[0,0.05]]), shape=[n_mc])
vel_speed = vel_speed.at[:,1].set(jnp.abs(vel_speed[:,1]));

In [None]:
all_all_values = v_val(vel_speed)

all_all_values = all_all_values[:,-1,...]

In [None]:
all_all_values.shape

In [None]:
d_val = jax.vmap(lambda x:grid.grad_values(jnp.squeeze(x)), 0, 0) 
d_all_all_values = d_val(all_all_values)

In [None]:
d_all_all_values.shape

In [None]:
plt.pcolor(grid.states[...,0], grid.states[...,1],all_all_values[332])
plt.colorbar()

In [None]:
# Interpolate every Monte-Carlo value function (and its gradient) at three probe
# points. The transpose moves the sample axis behind the spatial axes,
# presumably so one interpolation call returns a value per sample. TODO confirm.
# NOTE(review): the grid's y-range is [0, 1] (see the Box in the grid setup), so
# [1.0, 1.25] and [0.5, 1.25] lie outside the domain. The result there depends
# on hj's out-of-domain handling; confirm these probe points are intended.
val_p1 = grid.interpolate(all_all_values.transpose([1,2,0]),jnp.array([1.0,1.25]))
val_p2 = grid.interpolate(all_all_values.transpose([1,2,0]),jnp.array([0.8,0.25]))
val_p3 = grid.interpolate(all_all_values.transpose([1,2,0]),jnp.array([0.5,1.25]))

grad_val_p1 = grid.interpolate(d_all_all_values.transpose([1,2,0,3]),jnp.array([1.0,1.25]))
grad_val_p2 = grid.interpolate(d_all_all_values.transpose([1,2,0,3]),jnp.array([0.8,0.25]))
grad_val_p3 = grid.interpolate(d_all_all_values.transpose([1,2,0,3]),jnp.array([0.5,1.25]))


In [None]:
grad_val_p1.shape

In [None]:
controls_test = jnp.reshape(jnp.linspace(0,jnp.pi*2,201)[:-1],[-1,1])

In [None]:
controls_test.shape

In [None]:
# Directional derivative of the value function along each heading theta:
#   dV = V_x * cos(theta) + V_y * sin(theta).
# The last axis of the interpolated gradient follows the grid's state order
# (x, y), so the sin term must use column 1. It previously reused column 0.
val_func_increment = grad_val_p1[:,0]*jnp.cos(controls_test) + grad_val_p1[:,1]*jnp.sin(controls_test)
val_func_increment_p3 = grad_val_p3[:,0]*jnp.cos(controls_test) + grad_val_p3[:,1]*jnp.sin(controls_test)


In [None]:
import scipy.stats as stats

In [None]:
val_func_inc_stats = stats.describe(val_func_increment.T)
val_func_inc_stats_p3 = stats.describe(val_func_increment_p3.T)

In [None]:
plt.plot(controls_test,val_func_inc_stats.mean, label='Val func increment mean')
plt.plot(controls_test,val_func_inc_stats.variance, label='Val func increment variance')
plt.plot(controls_test,val_func_inc_stats.minmax[1], label='Val func increment max')
plt.plot(controls_test,val_func_inc_stats.minmax[0], label='Val func increment min')
plt.grid()
plt.xlim([0,2*np.pi])
plt.legend()

In [None]:
plt.plot(controls_test,val_func_inc_stats_p3.mean, label='Val func increment mean')
plt.plot(controls_test,val_func_inc_stats_p3.variance, label='Val func increment variance')
plt.plot(controls_test,val_func_inc_stats_p3.minmax[1], label='Val func increment max')
plt.plot(controls_test,val_func_inc_stats_p3.minmax[0], label='Val func increment min')
plt.grid()
plt.xlim([0,2*np.pi])
plt.legend()

In [None]:
val_func_inc_stats_p3.mean - val_func_inc_stats.mean

In [None]:
i=np.argmax(-val_func_inc_stats.mean)

sns.histplot(val_func_increment[i,:])
plt.title("{}*pi".format(controls_test[i][0]/jnp.pi))

In [None]:
i=np.argmin(val_func_inc_stats.variance)

sns.histplot(val_func_increment[i,:])
plt.title("{}*pi".format(controls_test[i][0]/jnp.pi))

In [None]:
val_func_increment.shape

In [None]:
data = pd.DataFrame([np.array(val_p1),
                     np.array(val_p2), 
                     np.array(vel_speed[...,0]),
                     np.array(vel_speed[...,1]),
                     np.arange(n_mc)]).T

In [None]:
# NOTE(review): the first column holds val_p1, which is interpolated at
# [1.0, 1.25], not [1.0, 1.5]. The label is kept unchanged because later
# plotting cells refer to the column by this name.
data.columns=['val_func at [1.0,1.5]','val_func at [0.8,0.25]','hway_vel', 'v_speed', 'id']

In [None]:
jplt = sns.jointplot(data=data,x='val_func at [0.8,0.25]', y='id')
jplt.plot_marginals(sns.histplot,kde=True)

In [None]:
jplt = sns.jointplot(data=data,x='val_func at [0.8,0.25]', y='hway_vel')
jplt.plot_marginals(sns.histplot,kde=True)

In [None]:
jplt = sns.jointplot(data=data,x='val_func at [1.0,1.5]', y='hway_vel')
jplt.plot_marginals(sns.histplot,kde=True)

In [None]:
jplt = sns.jointplot(data=data,x='val_func at [0.8,0.25]', y='v_speed')
jplt.plot_marginals(sns.histplot,kde=True)

In [None]:
jplt = sns.jointplot(data=data,x='val_func at [1.0,1.5]', y='v_speed')
jplt.plot_marginals(sns.histplot,kde=True)

# Single step assimilation

In [None]:
var_shape = all_all_values.shape[1:]
n_mc = all_all_values.shape[0]

all_all_values = all_all_values.reshape([n_mc, np.prod(var_shape)])

In [None]:
#all_all_values = all_all_values.reshape([n_mc, *var_shape])

In [None]:
all_all_values.shape

In [None]:
u,s,vt = jnp.linalg.svd(all_all_values,full_matrices=False)

In [None]:
ur = u[:,:100]
sr = s[:100]

In [None]:
# Reshape the right-singular vectors back to spatial fields. var_shape is a tuple
# and must be unpacked; nesting it ([-1, var_shape]) raises a TypeError.
vt.reshape([-1, *var_shape])

In [None]:
plt.loglog(s)