////////////////////////////////////////////////////////////////////////
//		THIS FILE IS THE INTERFACE FILE BETWEEN MATLAB AND C  //
////////////////////////////////////////////////////////////////////////

// Authors:   Matt Ueckermann, Pierre Lermusiaux, Pat Haley for MIT course 2.29

////////////////////////////////////
//  DECLARE STANDARD HEADER FILES //
////////////////////////////////////

#include "math.h"
#include "mex.h"
#include <stdlib.h>
#include <string.h>


////////////////////////////
//  MASTER FUNCTION FILE  //
////////////////////////////


//////////////////////////////////////////
// MATLAB INTERFACE FUNCTION DEFINITION //
//////////////////////////////////////////

/*
 * A MATLAB double array viewed as a flat buffer of n elements. `name` is used
 * only in the error raised when a node number falls outside the buffer.
 */
struct field {
    const double *data;
    size_t n;
    const char *name;
};

/* Return f->data[node - 1] for a 1-based MATLAB node number, raising a MATLAB
 * error instead of reading outside the array (which would crash MATLAB). */
static double node_value(const struct field *f, int node)
{
    if (node < 1 || (size_t) node > f->n) {
        mexErrMsgIdAndTxt("advect:badNodeNumber",
                          "%s node number %d is out of range", f->name, node);
    }
    return f->data[node - 1];
}

/* Return the data of a node-number table after checking that it is a 32-bit
 * integer array with at least `need` elements (MATLAB error otherwise). */
static const int *node_table(const mxArray *a, size_t need, const char *name)
{
    /* int32 is the documented type; uint32 has the same layout for the
     * positive node numbers stored here, so it is accepted as well. */
    if (!(mxIsInt32(a) || mxIsUint32(a)) || mxIsComplex(a)) {
        mexErrMsgIdAndTxt("advect:badNodeTable",
                          "%s must be a real int32 array (use int32(...) in MATLAB)",
                          name);
    }
    if (mxGetNumberOfElements(a) < need) {
        mexErrMsgIdAndTxt("advect:badNodeTable",
                          "%s has too few elements for the given Nx, Ny", name);
    }
    /* mxGetData, not mxGetPr: mxGetPr is only valid for double arrays under
     * the interleaved-complex MEX API. */
    return (const int *) mxGetData(a);
}

/*
 * Density on the face between interior cell node `pc` and neighbour node
 * `pnb`.
 *   - interior neighbour (pnb > Nbcs):   central average of the two values
 *   - Dirichlet ghost    (pnb <= NbcsDP): the same central average
 *   - Neumann ghost      (otherwise):     rho_c + 0.5*rho_ghost
 *     (presumably the ghost entry stores a gradient-type increment — confirm)
 */
static double face_density(const struct field *rho, int pc, int pnb,
                           int Nbcs, int NbcsDP)
{
    const double rc  = node_value(rho, pc);
    const double rnb = node_value(rho, pnb);

    if (pnb > Nbcs) {
        return 0.5 * rnb + 0.5 * rc;
    }
    if (pnb <= NbcsDP) {
        return 0.5 * (rc + rnb);
    }
    return rc + 0.5 * rnb;
}

/*
 * Normal velocity on a face with velocity node `vface`. `vother` is the
 * velocity node on the opposite face of the same cell.
 * The stored face value is used directly if the cell neighbour `pnb` is
 * interior, or if the face node is interior (> Nbcs) or Dirichlet
 * (<= NbcsDvel). Otherwise (Neumann face node) the two face values are summed.
 */
static double face_velocity(const struct field *vel, int pnb, int vface,
                            int vother, int Nbcs, int NbcsDvel)
{
    if (pnb > Nbcs || vface <= NbcsDvel || vface > Nbcs) {
        return node_value(vel, vface);
    }
    return node_value(vel, vface) + node_value(vel, vother);
}

/*
 * MATLAB gateway:
 *   Fu = <mexname>(Nx, Ny, dx, dy, rho, u, v, NodeP, Nodeu, Nodev,
 *                  Nbcs, NbcsDP, NbcsDu, NbcsDv)
 *
 * For every cell (i, j), 1 <= i <= Nx, 1 <= j <= Ny, of a grid padded with one
 * ghost layer, computes
 *     Fu = fw - fe + fs - fn,   f = (face density) * (face velocity) / h,
 * where the face density is a central average of the adjacent rho values.
 * h is dx for the west (i-1) / east (i+1) faces, and dy for the south (j+1) /
 * north (j-1) faces. Cells whose NodeP number is <= Nbcs are boundary cells
 * and get Fu = 0.
 *
 * Inputs
 *   Nx, Ny          number of interior cells in the i and j directions (>= 0)
 *   dx, dy          cell sizes
 *   rho, u, v       full double arrays, indexed by the 1-based node numbers
 *                   stored in the tables below
 *   NodeP           int32 cell-centre node table, column stride Ny+2
 *                   ((Ny+2)-by-(Nx+2) including ghost cells)
 *   Nodeu           int32 u-face node table, column stride Ny+2; column i-1 is
 *                   the west face and column i the east face of cell i
 *   Nodev           int32 v-face node table, column stride Ny+1; row j-1 is
 *                   the north face and row j the south face of cell j
 *   Nbcs            node numbers <= Nbcs are boundary (ghost) nodes
 *   NbcsDP          boundary rho nodes <= NbcsDP are Dirichlet, else Neumann
 *   NbcsDu, NbcsDv  the same Dirichlet/Neumann split for u and v face nodes
 *
 * Output
 *   Fu              (Nx*Ny)-by-1 double; cell (i, j) at index (j-1)+(i-1)*Ny
 *
 * NOTE(review): presumably u > 0 points toward i+1 and v > 0 toward j-1,
 * which makes Fu the discrete -div(rho*[u v]); confirm with the caller.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
    int Nx, Ny, Nbcs, NbcsDP, NbcsDu, NbcsDv, k;
    double dx, dy, inv_dx, inv_dy, *Fu;
    const int *NodeP, *Nodeu, *Nodev;
    struct field rho, u, v;
    size_t nx, ny, sP, sV, i, j, needP, needU, needV;

    /* ---- argument checking ---- */
    if (nrhs != 14) {
        mexErrMsgTxt("Need 14 input arguments");
    }
    /* MATLAB always provides plhs[0] (it becomes `ans` when nlhs == 0), so
     * only requests for more than one output are an error. */
    if (nlhs > 1) {
        mexErrMsgTxt("Too many output arguments: only Fu is returned");
    }

    Nx     = (int) mxGetScalar(prhs[0]);
    Ny     = (int) mxGetScalar(prhs[1]);
    dx     =       mxGetScalar(prhs[2]);
    dy     =       mxGetScalar(prhs[3]);
    Nbcs   = (int) mxGetScalar(prhs[10]);
    NbcsDP = (int) mxGetScalar(prhs[11]);
    NbcsDu = (int) mxGetScalar(prhs[12]);
    NbcsDv = (int) mxGetScalar(prhs[13]);

    if (Nx < 0 || Ny < 0) {
        mexErrMsgTxt("Nx and Ny must be non-negative");
    }
    nx = (size_t) Nx;
    ny = (size_t) Ny;
    sP = ny + 2;   /* column stride of NodeP and Nodeu */
    sV = ny + 1;   /* column stride of Nodev */

    for (k = 4; k <= 6; k++) {
        if (!mxIsDouble(prhs[k]) || mxIsSparse(prhs[k])) {
            mexErrMsgTxt("rho, u and v must be full double arrays");
        }
    }
    rho.data = mxGetPr(prhs[4]);
    rho.n    = mxGetNumberOfElements(prhs[4]);
    rho.name = "rho";
    u.data   = mxGetPr(prhs[5]);
    u.n      = mxGetNumberOfElements(prhs[5]);
    u.name   = "u";
    v.data   = mxGetPr(prhs[6]);
    v.n      = mxGetNumberOfElements(prhs[6]);
    v.name   = "v";

    /* Smallest table sizes that cover every entry the loop below can read
     * (the last, corner, entry of a full table is never touched). */
    if (nx > 0 && ny > 0) {
        needP = (nx + 2) * (ny + 2) - 1;
        needU = (nx + 1) * (ny + 2) - 1;
        needV = (nx + 1) * (ny + 1);
    } else {
        needP = needU = needV = 0;
    }
    NodeP = node_table(prhs[7], needP, "NodeP");
    Nodeu = node_table(prhs[8], needU, "Nodeu");
    Nodev = node_table(prhs[9], needV, "Nodev");

    /* ---- output ---- */
    plhs[0] = mxCreateDoubleMatrix(nx * ny, 1, mxREAL);
    Fu      = mxGetPr(plhs[0]);

    inv_dx = 1.0 / dx;
    inv_dy = 1.0 / dy;

    for (i = 1; i <= nx; i++) {
        for (j = 1; j <= ny; j++) {
            const size_t cP = j + i * sP;   /* this cell's entry in NodeP */
            int pc, pW, pE, pN, pS, uw, ue, vn, vs;
            double fw, fe, fs, fn;

            pc = NodeP[cP];
            if (pc <= Nbcs) {
                /* Boundary (ghost) cell: no advective contribution. */
                Fu[(j - 1) + (i - 1) * ny] = 0.0;
                continue;
            }

            pW = NodeP[cP - sP];   /* i-1 */
            pE = NodeP[cP + sP];   /* i+1 */
            pN = NodeP[cP - 1];    /* j-1 */
            pS = NodeP[cP + 1];    /* j+1 */
            uw = Nodeu[j + (i - 1) * sP];
            ue = Nodeu[j + i * sP];
            vn = Nodev[(j - 1) + i * sV];
            vs = Nodev[j + i * sV];

            fw = inv_dx * (face_density(&rho, pc, pW, Nbcs, NbcsDP) *
                           face_velocity(&u, pW, uw, ue, Nbcs, NbcsDu));
            fe = inv_dx * (face_density(&rho, pc, pE, Nbcs, NbcsDP) *
                           face_velocity(&u, pE, ue, uw, Nbcs, NbcsDu));
            fs = inv_dy * (face_density(&rho, pc, pS, Nbcs, NbcsDP) *
                           face_velocity(&v, pS, vs, vn, Nbcs, NbcsDv));
            fn = inv_dy * (face_density(&rho, pc, pN, Nbcs, NbcsDP) *
                           face_velocity(&v, pN, vn, vs, Nbcs, NbcsDv));

            /* Net flux into the cell: in through west/south, out east/north. */
            Fu[(j - 1) + (i - 1) * ny] = fw - fe + fs - fn;
        }
    }
}