# -*- coding: utf-8 -*-
"""A (WIP) library to help you with all your GMM need in JAX

Features include creating GMM objects, sampling from them, fitting them to
data, making pretty plots with them, and assimilating observations from linear
observation models. All non-plotting routines are wrapped in jit for speed.

TODO:
	* Implement jvp/vjp for fit_gmm to allow autodiff
	* Implement Jing's Inflation and Tapering. Inflation most likely not
	relevant here
	* Don't implement anything to do with dynamical systems in here. Keep a
	clear separation between this module and dynamical-systems filtering and
	smoothing.
	* Create a new module to implement filtering and smoothing on arbitrary
	dynamical systems using DO
	* Add the BIC criterion
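
Example usage (an illustrative sketch; the shapes and values below are arbitrary):

	import jax
	import jax.numpy as jnp

	# assuming GMM has been imported from this module
	key = jax.random.PRNGKey(0)
	gmm = GMM(jnp.array([0.5, 0.5]),                  # weights, [M]
			jnp.array([[0.0, 0.0], [2.0, 2.0]]),      # means, [M, d]
			jnp.stack([jnp.eye(2), jnp.eye(2)]))      # covariances, [M, d, d]
	ensemble = gmm.get_samples(key, 500)              # [500, d]
	fitted, n_iter = gmm.fit_gmm(ensemble)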
"""

import jax
import jax.numpy as jnp
import jax.scipy as jsc
import matplotlib.pyplot as plt
import functools
from jax.tree_util import register_pytree_node_class
import matplotlib.patches as mpatches

@register_pytree_node_class
class GMM(object):
	"""GMM. The parent GMM class which will create the objects we need
	"""

	def __init__(self,Π,μ,Σ):
		"""__init__. Constructor to initialize the object

		Parameters
		----------
		Π : jax.ndarray
			Array of weights of size [M] (M = number of Gaussian components)
		μ : jax.ndarray
			Array of means of size [M, d] (d = number of dimensions)
		Σ : jax.ndarray
			Array of covariances of size [M, d, d]
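
		Example (illustrative; a two-component mixture in two dimensions):

		>>> gmm = GMM(jnp.array([0.3, 0.7]),
		...           jnp.array([[0.0, 0.0], [3.0, 1.0]]),
		...           jnp.stack([jnp.eye(2), 0.5*jnp.eye(2)]))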
		"""
		self.weights = Π/jnp.sum(Π)
		self.means = μ 
		self.covs = Σ
		self.n = self.weights.shape[0]
		self.d = self.means.shape[1]
	
	def __repr__(self):
		"""__repr__. Pretty printing
		"""
		return "\n GMM with M={} \n Weights: \n{} \n Means: \n{} \n Covariances: \n{}\n".format(
			self.n, self.weights, self.means, self.covs)
		
	def tree_flatten(self):
		"""tree_flatten. To allow jax to treat this objext as a pytree
		"""
		return ([self.weights, self.means, self.covs],None)
	
	@classmethod
	def tree_unflatten(cls,aux_data,params):
		"""tree_unflatten. To allow jax to treat this objext as a pytree

		Parameters
		----------
		aux_data :
			From the flattened pytree
		params :
			From the flattened pytree
		"""
		return cls(params[0],params[1],params[2])
	
	# n_mc fixes an output shape, so it must be static; self is traced as a registered pytree
	@functools.partial(jax.jit, static_argnums=(2,))
	def get_samples(self, key, n_mc: int = 1):
		"""get_samples. Sample from the GMM distribution

		Parameters
		----------
		key : jax.ndarray
			JAX PRNG key. WARNING: read about how JAX handles random number
			generation; split a key before reusing it with this function
		n_mc : int
			number of ensemble members to be generated

		Returns:
		--------
		samples : jax.ndarray
			Samples in an array of size [n_mc, d]
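
		Example (illustrative; `gmm` is any GMM instance):

		>>> key = jax.random.PRNGKey(0)
		>>> ensemble = gmm.get_samples(key, 500)   # shape [500, d]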
		
		"""
		cat_key, multi_key = jax.random.split(key)
		latent_variables = jax.random.categorical(cat_key,
				jnp.log(self.weights), shape=[n_mc])
		latent_one_hot = jnp.expand_dims(
				jax.nn.one_hot(jnp.array(latent_variables), self.n),-1)
		all_samples = jax.random.multivariate_normal(multi_key, self.means,
				self.covs, shape=[n_mc,self.n])
		return jnp.sum(latent_one_hot*all_samples,axis=1)
	
	@jax.jit
	def get_pdf(self,x):
		"""get_pdf. Obtain the value of the pdf of the gmm and the values of the
		individual components. Note that the components are already multiplied by
		the respective weights

		Parameters
		----------
		x : jax.ndarray
			Point(s) at which you wish to evaluate the pdf of the components of the
			mixture and/or the total pdf


		Returns:
		--------
		pdf_value: jax.ndarray
			Total value of the pdf of the gmm at point(s) x
		components: list
			List of values of the individual components of the GMM
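
		Example (illustrative):

		>>> pdf_value, components = gmm.get_pdf(gmm.means)   # evaluate at the component means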
		"""
		components = []
		for i,π,μ,σ in zip(range(self.n),self.weights,self.means,self.covs):
			components.append(π*jsc.stats.multivariate_normal.pdf(x,μ,σ))
		return jnp.sum(jnp.array(components),axis=0),components
	
	@jax.jit
	def get_membership_weights(self,x):
		"""get_membership_weights. Given a data point(s), compute the likelihood it
		came from a component

		Parameters
		----------
		x : jnp.ndarray
			Data point(s) whose membership weights we need


		Returns:
		--------
		membership_weights: jax.ndarray
			Responsibility of each component for the data; size [M] for a single
			point and [M, n_mc] for a batch of points
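
		Example (illustrative; `x` is data of shape [n_mc, d]):

		>>> w = gmm.get_membership_weights(x)   # shape [M, n_mc]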
		"""
		likelihood, components = self.get_pdf(x)
		return jnp.array(components)/likelihood
	
	@jax.jit
	def _em_step(self, x):
		"""_em_step. Internal function to perfrom a single EM step

		Parameters
		----------
		x : jnp.ndarray
			Data over which we wish to fit. Size [n_mc, d]

		Returns:
		--------
		new_weights: jax.ndarray
		new_means: jax.ndarray
		new_covs: jax.ndarray
		"""
		wᵢₖ = self.get_membership_weights(x)
		Nₖ = jnp.sum(wᵢₖ,axis=-1)
		weights = Nₖ/x.shape[0]
		means = ((wᵢₖ.dot(x).T)/Nₖ).T
		# Use the freshly updated means in the covariance update (standard M-step)
		x_m_mu = jnp.expand_dims(x,1) - jnp.expand_dims(means,axis=0)
		covs = jnp.einsum('ki,ikd,ikj->kdj', wᵢₖ, x_m_mu,
				x_m_mu)/jnp.expand_dims(Nₖ,[-1,-2])
		return weights, means, covs
	
	@jax.jit
	def em_step(self, x):
		"""em_step. Wrapper around _em_step to perform a step of EM and assign the
		values to the GMM object.
		TODO: We might be able to merge in _em_step in here without jitting
		problems since we are now a pytree

		Parameters
		----------
		x : jnp.ndarray
			Data over which we wish to fit. Size [n_mc, d]

		Returns:
		--------
		updated_gmm_obj: GMM
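
		Example (illustrative; rebind the returned object rather than relying on
		in-place mutation, which jit does not preserve):

		>>> gmm = gmm.em_step(x)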
		"""
		weights, means, covs = self._em_step(x)
		self.weights = weights
		self.means = means
		self.covs = covs
		return self
	
	@jax.jit
	def get_posterior(self, H, y, R):
		"""get_posterior. Obtain the bayesian posterior given an observation from a
		linear observation model

		Parameters
		----------
		H : jnp.ndarray
			Observation matrix of size [n_obs, d]
		y : jnp.ndarray
			Observation of size [n_obs]
		R : jnp.ndarray
			Observation error covariance of size [n_obs, n_obs]

		Returns:
		--------
		updated_gmm_obj: GMM
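
		Example (illustrative; observe only the first state variable of a 2-D GMM):

		>>> H = jnp.array([[1.0, 0.0]])
		>>> R = jnp.array([[0.1]])
		>>> posterior = gmm.get_posterior(H, jnp.array([0.5]), R)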
		"""
		temp = jnp.linalg.inv(jnp.einsum('sp,jpq,rq->jsr',
			H, jnp.array(self.covs), H) + R)
		K = jnp.einsum('jqp,ip,jir->jqr',jnp.array(self.covs),H,temp)
		
		μ_hat = self.means + jnp.einsum('jik,jk->ji', K, 
									(y-jnp.einsum('ki,ji->jk',H,self.means)))
		
		
		def get_unnorm_weight(weight, mean, cov):
			return weight*jsc.stats.multivariate_normal.pdf(y,
														H.dot(mean),
														H.dot(cov.dot(H.T))+R)
		
		weights_a = jax.vmap(get_unnorm_weight, in_axes=(0,0,0),
						out_axes=0)(self.weights, self.means, self.covs)
		weights_a = weights_a/jnp.sum(weights_a)
	
		means_a = μ_hat
		covs_a = jnp.einsum('jde,jef->jdf',jnp.eye(self.d) - K.dot(H),self.covs)
		
		return GMM(weights_a, means_a, covs_a)
	
	@jax.jit
	def demean(self):
		"""demean. Make the GMM 0 mean

		Returns:
		--------
		updated_GMM: GMM
		with 0 mean
		mean: jnp.ndarray
		The removed mean of size [d]
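
		Example (illustrative):

		>>> centered_gmm, removed_mean = gmm.demean()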
		"""
		mean = jnp.sum(self.weights*self.means.T,-1).T
		self.means = self.means - mean
		return self, mean
	
	@jax.jit
	def fit_gmm(self,x, thresh=1e-6):
		"""fit_gmm. Fit the GMM to given data using the EM algorithm
		TODO: Implement max iterations

		Parameters
		----------
		x : jnp.ndarray
			Data of size [n_mc, d]
		thresh : float
			Maximum allowable L2 distance between the GMM means in successive
			iterations of the EM step (convergence criterion)

		Returns:
		--------
		updated_gmm_obj: GMM
			Now fitted to the data x
		num_iter: int
			Number of iterations required
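
		Example (an illustrative sketch; the initial guess below is arbitrary):

		>>> init = GMM(jnp.ones(2), x[:2], jnp.stack([jnp.eye(x.shape[1])]*2))
		>>> fitted, n_iter = init.fit_gmm(x, thresh=1e-6)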
		"""
		
		def sub_step(iteration_packet):
			gmm_obj = iteration_packet[0]
			iteration_packet[1] = gmm_obj.means
			gmm_obj = gmm_obj.em_step(x)
			iteration_packet[2] = iteration_packet[2]+1
			iteration_packet[0] = gmm_obj
			return iteration_packet
		
		def cond_func(iteration_packet):
			gmm_obj = iteration_packet[0]
			old_mean = iteration_packet[1]
			return jnp.linalg.norm(gmm_obj.means - old_mean)>thresh
		
		#Iteration packet: [GMM object, old means vector, num_iter]
		init_iteration_packet = [self, self.means-1, 0]
		ret_iteration_packet = jax.lax.while_loop(cond_func, sub_step,
				init_iteration_packet)
		fitted_gmm, _old_means, num_iter = ret_iteration_packet
		return fitted_gmm, num_iter
	
	def _plot_uni_pdf(self, ax, var_idx, left_plot=True, ls='-', lw=2):
		"""_plot_uni_pdf. Internal plotting utility
		"""
		ax.axis("off")
		weights = self.weights
		means = self.means[:,var_idx]
		stds = jnp.sqrt(self.covs[:,var_idx,var_idx])
		var_range = [jnp.min(means-2*stds),jnp.max(means+2*stds)]
		y = jnp.linspace(*var_range,100)
		
		total_z = 0
		for k in range(self.n):
			pi = weights[k]
			mean = means[k]
			std = stds[k]
			z = pi*jax.scipy.stats.norm.pdf(y,mean,std)
			total_z = total_z + z
			if left_plot:
				ax.plot(z, y, 
						color='C{}'.format(k), ls=ls, lw=lw)
			else:
				ax.plot(y, z,
						color='C{}'.format(k), ls=ls, lw=lw)
			
		if left_plot:
			ax.plot(total_z, y, color='k', lw=lw, ls=ls)
			ax.set_ylim(*var_range)
		else:
			ax.plot(y, total_z, color='k', lw=lw, ls=ls)
			ax.set_xlim(*var_range)
			
			
		return ax, jnp.max(total_z)
	
	def _plot_joint_pdf(self, ax, top_var, left_var, ls='-', lw=2):
		"""_plot_joint_pdf. Internal plotting utility
		"""
		means = self.means[:,[top_var, left_var]]
		covs = self.covs[:,[top_var,left_var],:][:,:,[top_var,left_var]]
		for k in range(self.n):
			ax.plot(means[k,0], means[k,1],'*',color='C{}'.format(k),mec='k')
			w, v = jnp.linalg.eigh(covs[k,:,:])
			w = jnp.sqrt(w)
			conf_interval = mpatches.Ellipse(xy=[means[k,0], means[k,1]],
								width=2*w[0],height=2*w[1],
								angle=180*jnp.arctan2(v[1,0],v[0,0])/jnp.pi,
								color='C{}'.format(k),
								alpha = float(self.weights[k]))
			ax.add_patch(conf_interval)

			conf_interval_edge = mpatches.Ellipse(xy=[means[k,0], means[k,1]],
								width=2*w[0],height=2*w[1],
								angle=180*jnp.arctan2(v[1,0],v[0,0])/jnp.pi,
								ec='C{}'.format(k),fc=None,
								fill=False, ls=ls, lw=lw)
			ax.add_patch(conf_interval_edge)
		return ax
		
	def plot(self, left_vars=None, 
					top_vars=None,
					samples=None, 
					fig=None, axs=None, 
					ls='-', lw=2,):
		"""plot. Get pretty plots of your GMM. SEE EXAMPLE NOTEBOOK

		Parameters
		----------
		left_vars : list
			Indices of variables that show up on the left
		top_vars : list
			Indices of variables that show up on the top
		samples : jnp.ndarray (optional)
			Data to plot with the GMM. Size [n_mc, d]. WARNING: do not pass a
			large number of data points; truncate to roughly 1000 points to keep
			matplotlib responsive
		fig : matplotlib.pyplot.figure (optional)
		axs : Nested list of matplotlib.pyplot.Axes (optional)
		ls : 
			linestyle (optional)
		lw :
			linewidth (optional)
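
		Example (illustrative; `samples` comes from e.g. get_samples):

		>>> fig, axs = gmm.plot(left_vars=[0, 1], top_vars=[0, 1],
		...                     samples=samples[:1000])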
		"""
		
		if left_vars is None:
			left_vars = jnp.arange(self.d,dtype=int)
		if top_vars is None:
			top_vars = jnp.arange(self.d,dtype=int)
		
		if axs is None:
			fig, axs = plt.subplots(len(left_vars)+1, len(top_vars)+1,
					figsize=[(len(top_vars)+1)*4, (len(left_vars)+1)*4],
					sharex='col', sharey='row', squeeze=False, 
					width_ratios=[3]*len(top_vars) + [1],
					height_ratios=[1]+[3]*len(left_vars))
			fig.subplots_adjust(hspace=0,wspace=0)
		
		axs[0][-1].axis("off")
		
		max_z = 0
		for i, left_var in enumerate(left_vars):
			ax = axs[i+1][-1]
			ax, max_z_i = self._plot_uni_pdf(ax, left_var, left_plot=True, ls=ls, lw=lw)
			max_z = jnp.maximum(max_z, max_z_i)
			axs[i+1][0].set_ylabel(r"$X_{}$".format(left_var), fontsize=16)
			
		for j, top_var in enumerate(top_vars):
			ax=axs[0][j]
			ax, max_z_i = self._plot_uni_pdf(ax, top_var, left_plot=False, ls=ls, lw=lw)
			max_z = jnp.maximum(max_z, max_z_i)
			axs[-1][j].set_xlabel(r"$X_{}$".format(top_var), fontsize=16)
			
		axs[0][-1].set_xlim([0,max_z])
		axs[0][-1].set_ylim([0,max_z])
		
		for i, left_var in enumerate(left_vars):
			for j, top_var in enumerate(top_vars):
				ax = axs[i+1][j]
				ax = self._plot_joint_pdf(ax, top_var, left_var, ls=ls, lw=lw)
				
				if samples is not None:
					ax.plot(samples[:,top_var], samples[:,left_var],'k.',alpha=0.1)
		return fig, axs