# -*- coding: utf-8 -*-
"""A (WIP) library to help you with all your GMM needs in JAX.

Features include creating GMM objects, sampling from them, fitting them to
data, making pretty plots with them, and assimilating observations from
linear observation models. All non-plotting routines are wrapped in jit for
speed.

TODO:
    * Implement jvp/vjp for fit_gmm to allow autodiff
    * Implement Jing's Inflation and Tapering. Inflation most likely not
      relevant here
    * Don't implement anything to do with dynamical systems in here. Keep a
      clear separation between this module and dynamical systems, filtering,
      and smoothing.
    * Create a new module to implement filtering and smoothing on arbitrary
      dynamical systems using DO
    * Add the BIC criterion
"""
import jax
import jax.numpy as jnp
import jax.scipy as jsc
import matplotlib.pyplot as plt
import functools
from jax.tree_util import register_pytree_node_class
import matplotlib.patches as mpatches


@register_pytree_node_class
class GMM(object):
    """GMM.

    The parent GMM class which creates the objects we need.
    """

    def __init__(self, Π, μ, Σ):
        """__init__.

        Constructor to initialize the object.

        Parameters
        ----------
        Π : jax.ndarray
            Array of weights of size [M] (M = number of Gaussian mixture
            components)
        μ : jax.ndarray
            Array of means of size [M, d] (d = number of dimensions)
        Σ : jax.ndarray
            Array of covariances of size [M, d, d]
        """
        self.weights = Π / jnp.sum(Π)
        self.means = μ
        self.covs = Σ
        self.n = self.weights.shape[0]
        self.d = self.means.shape[1]

    def __repr__(self):
        """__repr__.

        Pretty printing.
        """
        return ("\n GMM with M={} \n Weights: \n{} \n Means: \n{} \n"
                " Covariances: \n{}\n").format(
                    self.n, self.weights, self.means, self.covs)

    def tree_flatten(self):
        """tree_flatten.

        To allow JAX to treat this object as a pytree.
        """
        return ([self.weights, self.means, self.covs], None)

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """tree_unflatten.

        To allow JAX to treat this object as a pytree.

        Parameters
        ----------
        aux_data :
            From the flattened pytree
        params :
            From the flattened pytree
        """
        return cls(params[0], params[1], params[2])

    @functools.partial(jax.jit, static_argnums=(0, 2))
    def get_samples(self, key, n_mc: int = 1):
        """get_samples.

        Sample from the GMM distribution.

        Parameters
        ----------
        key : jax.random.PRNGKey
            JAX PRNG key. WARNING: read about how JAX handles generation of
            random numbers. Split a key to use with this function.
        n_mc : int
            Number of ensemble members to be generated

        Returns
        -------
        samples : jax.ndarray
            Samples in an array of size [n_mc, d]
        """
        cat_key, multi_key = jax.random.split(key)
        # Draw the mixture component index for each sample ...
        latent_variables = jax.random.categorical(
            cat_key, jnp.log(self.weights), shape=[n_mc])
        latent_one_hot = jnp.expand_dims(
            jax.nn.one_hot(jnp.array(latent_variables), self.n), -1)
        # ... then draw from every component and select via the one-hot mask.
        all_samples = jax.random.multivariate_normal(
            multi_key, self.means, self.covs, shape=[n_mc, self.n])
        return jnp.sum(latent_one_hot * all_samples, axis=1)
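
    # Usage sketch (hypothetical shapes and names, not part of the library):
    # build a 2-component bivariate mixture and draw 500 samples.
    #
    #   key = jax.random.PRNGKey(0)
    #   gmm = GMM(Π=jnp.array([0.3, 0.7]),
    #             μ=jnp.zeros((2, 2)),
    #             Σ=jnp.stack([jnp.eye(2), 2.0 * jnp.eye(2)]))
    #   samples = gmm.get_samples(key, n_mc=500)   # shape [500, 2]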

    @jax.jit
    def get_pdf(self, x):
        """get_pdf.

        Obtain the value of the pdf of the GMM and the values of the
        individual components. Note that the components are already
        multiplied by their respective weights.

        Parameters
        ----------
        x : jax.ndarray
            Point(s) at which you wish to evaluate the pdf of the components
            of the mixture and/or the total pdf

        Returns
        -------
        pdf_value : jax.ndarray
            Total value of the pdf of the GMM at point(s) x
        components : list
            List of values of the individual components of the GMM
        """
        components = []
        for i, π, μ, σ in zip(range(self.n), self.weights, self.means,
                              self.covs):
            components.append(π * jsc.stats.multivariate_normal.pdf(x, μ, σ))
        return jnp.sum(jnp.array(components), axis=0), components

    @jax.jit
    def get_membership_weights(self, x):
        """get_membership_weights.

        Given data point(s), compute the likelihood that they came from each
        component.

        Parameters
        ----------
        x : jnp.ndarray
            Data point(s) whose membership weights we need

        Returns
        -------
        membership_weights : jax.ndarray of size [M]
            (or [M, n_mc] for a batch of points) Likelihood that this data
            came from a given component, for each of the components
        """
        likelihood, components = self.get_pdf(x)
        return jnp.array(components) / likelihood

    @jax.jit
    def _em_step(self, x):
        """_em_step.

        Internal function to perform a single EM step.

        Parameters
        ----------
        x : jnp.ndarray
            Data over which we wish to fit. Size [n_mc, d]

        Returns
        -------
        new_weights : jax.ndarray
        new_means : jax.ndarray
        new_covs : jax.ndarray
        """
        # E-step: responsibilities wᵢₖ[k, i] of component k for data point i.
        wᵢₖ = self.get_membership_weights(x)
        Nₖ = jnp.sum(wᵢₖ, axis=-1)
        # M-step: re-estimate weights, means, and covariances.
        weights = Nₖ / x.shape[0]
        means = ((wᵢₖ.dot(x).T) / Nₖ).T
        x_m_mu = jnp.expand_dims(x, 1) - jnp.expand_dims(self.means, axis=0)
        covs = jnp.einsum('ki,ikd,ikj->kdj', wᵢₖ, x_m_mu,
                          x_m_mu) / jnp.expand_dims(Nₖ, [-1, -2])
        return weights, means, covs

    @jax.jit
    def em_step(self, x):
        """em_step.

        Wrapper around _em_step to perform a step of EM and assign the
        values to the GMM object.

        TODO: We might be able to merge _em_step in here without jitting
        problems since we are now a pytree.

        Parameters
        ----------
        x : jnp.ndarray
            Data over which we wish to fit. Size [n_mc, d]

        Returns
        -------
        updated_gmm_obj : GMM
        """
        weights, means, covs = self._em_step(x)
        self.weights = weights
        self.means = means
        self.covs = covs
        return self

    @jax.jit
    def get_posterior(self, H, y, R):
        """get_posterior.

        Obtain the Bayesian posterior given an observation from a linear
        observation model.

        Parameters
        ----------
        H : jnp.ndarray
            Observation matrix of size [n_obs, d]
        y : jnp.ndarray
            Observation of size [n_obs]
        R : jnp.ndarray
            Observation error covariance of size [n_obs, n_obs]

        Returns
        -------
        updated_gmm_obj : GMM
        """
        x_bar = jnp.sum(self.weights * self.means.T, -1).T
        # Per-component Kalman gain K and posterior means μ̂.
        temp = jnp.linalg.inv(
            jnp.einsum('sp,jpq,rq->jsr', H, jnp.array(self.covs), H) + R)
        K = jnp.einsum('jqp,ip,jir->jqr', jnp.array(self.covs), H, temp)
        μ_hat = self.means + jnp.einsum(
            'jik,jk->ji', K,
            (y - jnp.einsum('ki,ji->jk', H, self.means)))

        def get_unnorm_weight(weight, mean, cov):
            return weight * jsc.stats.multivariate_normal.pdf(
                y, H.dot(mean), H.dot(cov.dot(H.T)) + R)

        weights_a = jax.vmap(get_unnorm_weight, in_axes=(0, 0, 0),
                             out_axes=0)(self.weights, self.means, self.covs)
        weights_a = weights_a / jnp.sum(weights_a)
        means_a = μ_hat
        covs_a = jnp.einsum('jde,jef->jdf',
                            jnp.eye(self.d) - K.dot(H), self.covs)
        return GMM(weights_a, means_a, covs_a)

    @jax.jit
    def demean(self):
        """demean.

        Make the GMM zero mean.

        Returns
        -------
        updated_GMM : GMM
            GMM with zero mean
        mean : jnp.ndarray
            The removed mean of size [d]
        """
        mean = jnp.sum(self.weights * self.means.T, -1).T
        self.means = self.means - mean
        return self, mean
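
    # Usage sketch for the linear-observation update (hypothetical values):
    # observe the first coordinate of a d=2 state with noise variance 0.1.
    #
    #   H = jnp.array([[1.0, 0.0]])          # [n_obs, d]
    #   y = jnp.array([0.5])                 # [n_obs]
    #   R = jnp.array([[0.1]])               # [n_obs, n_obs]
    #   posterior = gmm.get_posterior(H, y, R)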

    @jax.jit
    def fit_gmm(self, x, thresh=1e-6):
        """fit_gmm.

        Fit the GMM to given data using the EM algorithm.

        TODO: Implement max iterations

        Parameters
        ----------
        x : jnp.ndarray
            Data of size [n_mc, d]
        thresh : float
            Maximum allowable L2 distance between the means of the GMMs in
            subsequent iterations of the EM step

        Returns
        -------
        updated_gmm_obj : GMM
            Now fitted to the data x
        num_iter : int
            Number of iterations required
        """
        def sub_step(iteration_packet):
            gmm_obj = iteration_packet[0]
            iteration_packet[1] = gmm_obj.means
            gmm_obj = gmm_obj.em_step(x)
            iteration_packet[2] = iteration_packet[2] + 1
            iteration_packet[0] = gmm_obj
            return iteration_packet

        def cond_func(iteration_packet):
            gmm_obj = iteration_packet[0]
            old_mean = iteration_packet[1]
            return jnp.linalg.norm(gmm_obj.means - old_mean) > thresh

        # Iteration packet: [GMM object, old means vector, num_iter]
        init_iteration_packet = [self, self.means - 1, 0]
        ret_iteration_packet = jax.lax.while_loop(cond_func, sub_step,
                                                  init_iteration_packet)
        fitted_gmm, old_means, num_iter = ret_iteration_packet
        return fitted_gmm, num_iter

    def _plot_uni_pdf(self, ax, var_idx, left_plot=True, ls='-', lw=2):
        """_plot_uni_pdf.

        Internal plotting utility.
        """
        ax.axis("off")
        weights = self.weights
        means = self.means[:, var_idx]
        stds = jnp.sqrt(self.covs[:, var_idx, var_idx])
        var_range = [jnp.min(means - 2 * stds), jnp.max(means + 2 * stds)]
        y = jnp.linspace(*var_range, 100)
        total_z = 0
        for k in range(self.n):
            pi = weights[k]
            mean = means[k]
            std = stds[k]
            z = pi * jax.scipy.stats.norm.pdf(y, mean, std)
            total_z = total_z + z
            if left_plot:
                ax.plot(z, y, color='C{}'.format(k), ls=ls, lw=lw)
            else:
                ax.plot(y, z, color='C{}'.format(k), ls=ls, lw=lw)
        if left_plot:
            ax.plot(total_z, y, color='k', lw=lw, ls=ls)
            ax.set_ylim(*var_range)
        else:
            ax.plot(y, total_z, color='k', lw=lw, ls=ls)
            ax.set_xlim(*var_range)
        return ax, jnp.max(total_z)

    def _plot_joint_pdf(self, ax, top_var, left_var, ls='-', lw=2):
        """_plot_joint_pdf.

        Internal plotting utility.
        """
        means = self.means[:, [top_var, left_var]]
        covs = self.covs[:, [top_var, left_var], :][:, :, [top_var, left_var]]
        for k in range(self.n):
            ax.plot(means[k, 0], means[k, 1], '*', color='C{}'.format(k),
                    mec='k')
            # Confidence ellipse from the eigendecomposition of the 2x2
            # marginal covariance.
            w, v = jnp.linalg.eigh(covs[k, :, :])
            w = jnp.sqrt(w)
            conf_interval = mpatches.Ellipse(
                xy=[means[k, 0], means[k, 1]],
                width=2 * w[0], height=2 * w[1],
                angle=180 * jnp.arctan2(v[1, 0], v[0, 0]) / jnp.pi,
                color='C{}'.format(k), alpha=float(self.weights[k]))
            ax.add_patch(conf_interval)
            conf_interval_edge = mpatches.Ellipse(
                xy=[means[k, 0], means[k, 1]],
                width=2 * w[0], height=2 * w[1],
                angle=180 * jnp.arctan2(v[1, 0], v[0, 0]) / jnp.pi,
                ec='C{}'.format(k), fc=None, fill=False, ls=ls, lw=lw)
            ax.add_patch(conf_interval_edge)
        return ax
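
    # Usage sketch for fitting (hypothetical names): fit an initialised GMM
    # to data `x` of shape [n_mc, d] and inspect the iteration count.
    #
    #   fitted, num_iter = gmm.fit_gmm(x, thresh=1e-6)
    #   print(fitted, num_iter)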

    def plot(self, left_vars=None, top_vars=None, samples=None, fig=None,
             axs=None, ls='-', lw=2):
        """plot.

        Get pretty plots of your GMM. SEE EXAMPLE NOTEBOOK.

        Parameters
        ----------
        left_vars : list
            Indices of variables that show up on the left
        top_vars : list
            Indices of variables that show up on the top
        samples : jnp.ndarray (optional)
            Data to plot with the GMM. Size [n_mc, d]. WARNING: do not
            include a large number of data points. Truncate to around 1000
            data points to keep matplotlib from choking.
        fig : matplotlib.pyplot.figure (optional)
        axs : nested list of matplotlib.pyplot.Axes (optional)
        ls : linestyle (optional)
        lw : linewidth (optional)
        """
        if left_vars is None:
            left_vars = jnp.arange(self.d, dtype=int)
        if top_vars is None:
            top_vars = jnp.arange(self.d, dtype=int)
        if axs is None:
            fig, axs = plt.subplots(
                len(left_vars) + 1, len(top_vars) + 1,
                figsize=[(len(top_vars) + 1) * 4, (len(left_vars) + 1) * 4],
                sharex='col', sharey='row', squeeze=False,
                width_ratios=[3] * len(top_vars) + [1],
                height_ratios=[1] + [3] * len(left_vars))
            fig.subplots_adjust(hspace=0, wspace=0)
        axs[0][-1].axis("off")
        max_z = 0
        # Marginal pdfs along the left column and the top row.
        for i, left_var in enumerate(left_vars):
            ax = axs[i + 1][-1]
            ax, max_z_i = self._plot_uni_pdf(ax, left_var, left_plot=True,
                                             ls=ls, lw=lw)
            max_z = jnp.maximum(max_z, max_z_i)
            axs[i + 1][0].set_ylabel(r"$X_{}$".format(left_var), fontsize=16)
        for j, top_var in enumerate(top_vars):
            ax = axs[0][j]
            ax, max_z_i = self._plot_uni_pdf(ax, top_var, left_plot=False,
                                             ls=ls, lw=lw)
            max_z = jnp.maximum(max_z, max_z_i)
            axs[-1][j].set_xlabel(r"$X_{}$".format(top_var), fontsize=16)
        axs[0][-1].set_xlim([0, max_z])
        axs[0][-1].set_ylim([0, max_z])
        # Pairwise joint pdfs (confidence ellipses) in the interior grid.
        for i, left_var in enumerate(left_vars):
            for j, top_var in enumerate(top_vars):
                ax = axs[i + 1][j]
                ax = self._plot_joint_pdf(ax, top_var, left_var, ls=ls,
                                          lw=lw)
                if samples is not None:
                    ax.plot(samples[:, top_var], samples[:, left_var], 'k.',
                            alpha=0.1)
        return fig, axs
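

# A minimal end-to-end sketch (not part of the library API): builds a small
# mixture, samples from it, fits a fresh GMM to those samples, and applies a
# linear-observation update. All names below (`truth`, `guess`, etc.) are
# illustrative.
if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    sample_key, key = jax.random.split(key)

    # A 2-component, 2-d "truth" mixture to generate data from.
    truth = GMM(jnp.array([0.4, 0.6]),
                jnp.array([[-2.0, 0.0], [2.0, 1.0]]),
                jnp.stack([jnp.eye(2), 0.5 * jnp.eye(2)]))
    x = truth.get_samples(sample_key, n_mc=1000)

    # Fit a deliberately poor initial guess to the samples with EM.
    guess = GMM(jnp.array([0.5, 0.5]),
                jnp.array([[-1.0, -1.0], [1.0, 1.0]]),
                jnp.stack([jnp.eye(2), jnp.eye(2)]))
    fitted, num_iter = guess.fit_gmm(x)
    print("EM converged in {} iterations".format(num_iter))
    print(fitted)

    # Assimilate a scalar observation of the first state variable.
    H = jnp.array([[1.0, 0.0]])
    y = jnp.array([0.5])
    R = jnp.array([[0.1]])
    posterior = fitted.get_posterior(H, y, R)
    print(posterior)

    # Plot the fitted mixture against the (truncated) samples.
    fig, axs = fitted.plot(samples=x[:1000])
    plt.show()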