#include <stdio.h>
#include <stdlib.h>
#include <R.h>
#include <R_ext/BLAS.h>
#include <R_ext/Lapack.h>
#include "clp.h"
#include "clpmisc.h"

/*
    m, n: positive integers with m >= n
    A: m by n matrix
    tau: n-length vector
*/
CLP_INT compute_qr(const CLP_INT m, const CLP_INT n, double *A, double *tau)
{
    INFO_CLP info = SUCCESS;
    int info0 = 0;
    CLP_INT lwork = -1;
    double work_query;
    double *work = NULL;
    F77_NAME(dgeqrf)(&m, &n, A, &m, tau, &work_query, &lwork, &info0);
    if (info0 != 0)
    {
        info = ERROR_QR;
        goto EXCEPTION;
    }

    lwork = (CLP_INT)work_query;
    work = (double*) CLP_MALLOC(sizeof(double) * lwork);
    CHECKNULL2(work);
    F77_NAME(dgeqrf)(&m, &n, A, &m, tau, work, &lwork, &info0);
    if (info0 != 0)
    {
        info = ERROR_QR;
        goto EXCEPTION;
    }

EXCEPTION:
    CLP_FREE(work);
    return info;
}

/*
    m, n: positive integers with m >= n
    A: m by n matrix
    tau: n-length vector
    x: m-length vector
*/
CLP_INT compute_Qx(const char trans, const CLP_INT m, const CLP_INT n,
    const double *A, const double *tau, double *x)
{
    INFO_CLP info = SUCCESS;
    CLP_INT info1 = 0;
    CLP_INT lwork = -1;
    char side = 'L';
    // char trans = 'N';
    double work_query;
    double *work = NULL;
    CLP_INT n1 = 1;
    F77_NAME(dormqr)(&side, &trans, &m, &n1, &n, A, &m, tau, x, &m, 
        &work_query, &lwork, &info1 FCONE FCONE);
    if (info1 != 0)
    {
        info = ERROR_QX;
        goto EXCEPTION;
    }
    lwork = (CLP_INT)work_query;
    work = (double*) CLP_MALLOC(sizeof(double) * lwork);
    CHECKNULL2(work);

    F77_NAME(dormqr)(&side, &trans, &m, &n1, &n, A, &m, tau, x, &m, 
        work, &lwork, &info1 FCONE FCONE);
    if (info1 != 0)
    {
        info = ERROR_QX;
        goto EXCEPTION;
    }

EXCEPTION:
    CLP_FREE(work);
    return info;
}

/*
    Triangular solve in place through BLAS dtrsv (upper, non-unit diagonal).
    trans: forwarded to dtrsv as TRANS (this file passes 'N')
    m, n: positive integers with m >= n
    R: n by n upper triangular matrix, embedded in an m by n matrix
       (leading dimension m)
    x: length n vector; right-hand side on entry, solution on return
*/
void solve_Rx(const char trans, const CLP_INT m, const CLP_INT n,
    const double *R, double *x)
{
    const char upper = 'U';
    const char non_unit = 'N';
    const CLP_INT stride = 1;

    F77_NAME(dtrsv)(&upper, &trans, &non_unit, &n, R, &m, x, &stride
        FCONE FCONE FCONE);
}

/*
    Compute x, s and y from the QR data stored in (At, tau).
    n, m: dimensions of At
    At: n by m matrix, containing QR data (Q and R)
    tau: vector generated by the QR factorization
    RinvtRp: m-length vector
    Rd, Rc: n-length vectors
    x, s: n-length output vectors (s also serves as scratch space)
    y: m-length output vector
    Returns SUCCESS, or the error code reported by compute_Qx.
*/
CLP_INT solve_normalEquation(const CLP_INT n, const CLP_INT m, const double *At,
    const double *tau, const double *RinvtRp, const double *Rd, 
    const double *Rc, double *x, double *s, double *y)
{
    INFO_CLP info = SUCCESS;
    size_t k;

    // s <- Q^T * (Rd - Rc)
    for (k = 0; k < n; k++)
    {
        s[k] = Rd[k] - Rc[k];
    }
    info = compute_Qx('T', n, m, At, tau, s);
    CHECKINFO(info);

    // x holds the vector that Q is applied to below:
    // leading m entries come from RinvtRp, trailing n - m entries from -s.
    for (k = 0; k < m; k++)
    {
        x[k] = RinvtRp[k];
    }
    for (k = m; k < n; k++)
    {
        x[k] = -s[k];
    }

    // y <- R^{-1} * (s + RinvtRp), using the leading m entries of s
    for (k = 0; k < m; k++)
    {
        y[k] = s[k] + RinvtRp[k];
    }
    solve_Rx('N', n, m, At, y);

    // x <- Q * x
    info = compute_Qx('N', n, m, At, tau, x);
    CHECKINFO(info);

    // s <- Rc - x
    for (k = 0; k < n; k++)
    {
        s[k] = Rc[k] - x[k];
    }

EXCEPTION:
    return info;
}

/*
    Compute a minimum eigenvalue of symmetric matrix S through LAPACK dsyevr
    (eigenvalues only, index range [1, 1], upper triangle).
    n: order of S
    S: n by n symmetric matrix, leading dimension n; handed to dsyevr as a
       non-const work array
    w: receives the eigenvalue
    m: receives the number of eigenvalues found by dsyevr
    Returns SUCCESS, or ERROR_EIG when dsyevr reports a nonzero info.
*/
CLP_INT compute_minEig(const CLP_INT n, double *S, double *w, CLP_INT *m)
{
    INFO_CLP info = SUCCESS;
    CLP_INT info1;

    // dsyevr job description
    char jobz = 'N';
    char range = 'I';
    char uplo = 'U';
    CLP_INT il = 1, iu = 1;
    double vl = 0.0, vu = 0.0;
    double abstol = -1.0;

    // workspaces, sized by the query call below
    double *work = NULL;
    CLP_INT *iwork = NULL;
    double work_query;
    CLP_INT iwork_query;
    CLP_INT lwork = -1;
    CLP_INT liwork = -1;

    // Query call: lwork == liwork == -1
    F77_NAME(dsyevr)(&jobz, &range, &uplo, &n, S, &n, &vl, &vu, &il, &iu,
        &abstol, m, w, NULL, &n, NULL,
        &work_query, &lwork, &iwork_query, &liwork,
        &info1 FCONE FCONE FCONE);
    if (info1 == 0)
    {
        lwork = (CLP_INT)work_query;
        work = (double*) CLP_MALLOC(sizeof(double)*lwork);
        CHECKNULL2(work);
        liwork = iwork_query;
        iwork = (CLP_INT*) CLP_MALLOC(sizeof(CLP_INT)*iwork_query);
        CHECKNULL2(iwork);

        // Actual eigenvalue computation
        F77_NAME(dsyevr)(&jobz, &range, &uplo, &n, S, &n, &vl, &vu, &il, &iu,
            &abstol, m, w, NULL, &n, NULL,
            work, &lwork, iwork, &liwork,
            &info1 FCONE FCONE FCONE);
    }
    if (info1 != 0)
    {
        info = ERROR_EIG;
    }

EXCEPTION:
    CLP_FREE(work);
    CLP_FREE(iwork);
    return info;
}

CLP_INT compute_svd(const CLP_INT n, double *A, double *s)
{
    INFO_CLP info = SUCCESS;
    char jobu = 'N';
    char jobvt = 'O';
    double *Vt = NULL;
    double *work = NULL;
    double work_query;
    CLP_INT lwork = -1;
    CLP_INT info1 = 0;

    F77_NAME(dgesvd)(&jobu, &jobvt, &n,
		 &n, A, &n, s,
		 NULL, &n, Vt, &n,
		 &work_query, &lwork, &info1 FCONE FCONE);
    if (info1 != 0)
    {
        info = ERROR_SVD;
        goto EXCEPTION;
    }
    lwork = (CLP_INT)work_query;
    work = (double*) CLP_MALLOC(sizeof(double)*lwork);
    CHECKNULL2(work);
    F77_NAME(dgesvd)(&jobu, &jobvt, &n,
		 &n, A, &n, s,
		 NULL, &n, Vt, &n,
		 work, &lwork, &info1 FCONE FCONE);
    if (info1 != 0)
    {
        info = ERROR_SVD;
        goto EXCEPTION;
    }

EXCEPTION:
    CLP_FREE(work);
    return info;
}

/*
    Copy all or part of matrix x into matrix z (column-major storage).
    mode: 'A'll / 'L'ower triangle / 'U'pper triangle / 'D'iagonal.
          Both triangles include the diagonal. Any other value copies nothing.
    m: # of rows
    n: # of columns
    x: source matrix
    ldx: leading dimension of x
    z: destination matrix; entries outside the selected part are left as is
    ldz: leading dimension of z
*/
void copy_mat(const char mode, const CLP_INT m, const CLP_INT n, 
    const double *x, const CLP_INT ldx, double *z, const CLP_INT ldz)
{
    if (mode == 'A')
    {
        for (size_t j=0; j<n; ++j)
        {
            for (size_t i=0; i<m; ++i)
            {
                z[i+j*ldz] = x[i+j*ldx];
            }
        }
    }
    else if (mode == 'L')
    {
        for (size_t j=0; j<n; ++j)
        {
            for (size_t i=j; i<m; ++i)
            {
                z[i+j*ldz] = x[i+j*ldx];
            }
        }
    }
    else if (mode == 'U')
    {
        for (size_t j=0; j<n; ++j)
        {
            // Rows run up to the diagonal but never past the last row, so a
            // wide matrix (m < n) is not accessed out of range.
            for (size_t i=0; i<=j && i<m; ++i)
            {
                z[i+j*ldz] = x[i+j*ldx];
            }
        }
    }
    else if (mode == 'D')
    {
        // The diagonal has min(m, n) entries.
        for (size_t j=0; j<n && j<m; ++j)
        {
            z[j+j*ldz] = x[j+j*ldx];
        }
    }
}

/*
    Set every entry of a vector to 0.0.
    n: length of vector
    x: vector, overwritten
*/
void zerofill_vec(const CLP_INT n, double *x)
{
    size_t k = 0;
    while (k < n)
    {
        x[k] = 0.0;
        ++k;
    }
}

/*
    Fill a matrix (column-major storage) with zeros.
    mode: 'A'll / 'L'ower / 'U'pper.
          'L' zeroes the part strictly below the diagonal, 'U' the part
          strictly above it; the diagonal itself is kept in both cases.
          Any other value leaves x untouched.
    m: # of rows
    n: # of columns
    x: matrix
    ldx: leading dimension of x
*/
void zerofill_mat(const char mode, const CLP_INT m, const CLP_INT n,
    double *x, const CLP_INT ldx)
{
    if (mode == 'A')
    {
        for (size_t j=0; j<n; ++j)
        {
            for (size_t i=0; i<m; ++i)
            {
                x[i+j*ldx] = 0.0;
            }
        }
    }
    else if (mode == 'L')
    {
        for (size_t j=0; j<n; ++j)
        {
            // Row index is bounded by the row count m (not by n), so tall
            // matrices are cleared down to the last row and wide ones are
            // never written out of range.
            for (size_t i=j+1; i<m; ++i)
            {
                x[i+j*ldx] = 0.0;
            }
        }
    }
    else if (mode == 'U')
    {
        for (size_t j=0; j<n; ++j)
        {
            // Stop at the diagonal or at the last row, whichever is first.
            for (size_t i=0; i<j && i<m; ++i)
            {
                x[i+j*ldx] = 0.0;
            }
        }
    }
}

/*
    Scale a matrix in place by a diagonal matrix:
    x <- Diag(d) * x  (side 'L')  or  x <- x * Diag(d)  (side 'R').
    Any other value of side leaves x untouched.
    side: 'L'eft / 'R'ight
    m: # of rows
    n: # of columns
    x: a matrix (column-major)
    ldx: leading dimension of x
    d: vector part of diagonal matrix (indexed by row for 'L', by column
       for 'R')
*/
void mul_diagMat(const char side, const CLP_INT m, const CLP_INT n, double *x,
    const CLP_INT ldx, const double *d)
{
    size_t row, col;

    switch (side)
    {
    case 'L':
        // every entry of row `row` is multiplied by d[row]
        for (row = 0; row < m; ++row)
        {
            for (col = 0; col < n; ++col)
            {
                x[row + col*ldx] *= d[row];
            }
        }
        break;
    case 'R':
        // every entry of column `col` is multiplied by d[col]
        for (col = 0; col < n; ++col)
        {
            for (row = 0; row < m; ++row)
            {
                x[row + col*ldx] *= d[col];
            }
        }
        break;
    default:
        break;
    }
}

/*
    Cholesky factorization with dpotrf (uplo = 'U').
    Only the upper triangle is read, and it is overwritten by the factor.
    n: # of rows/columns
    x: PD matrix
    ldx: leading dimension of x
    Returns SUCCESS, or ERROR_CHOL when dpotrf reports a nonzero info.
*/
CLP_INT compute_chol(const CLP_INT n, double *x, const CLP_INT ldx)
{
    const char uplo = 'U';
    CLP_INT lapack_info = 0;

    F77_NAME(dpotrf)(&uplo, &n, x, &ldx, &lapack_info FCONE);

    INFO_CLP info = (lapack_info == 0) ? SUCCESS : ERROR_CHOL;
    return info;
}

/*
    Write the transpose of x into a separate matrix z.
    m, n: dimensions of x
    x: m by n matrix (column-major, leading dimension m)
    z: n by m matrix (column-major, leading dimension n), transpose of x
*/
void transpose(const CLP_INT m, const CLP_INT n, const double *x, double *z)
{
    for (size_t col = 0; col < n; ++col)
    {
        for (size_t row = 0; row < m; ++row)
        {
            const size_t from = row + col*m;    // (row, col) in x
            const size_t to = col + row*n;      // (col, row) in z
            z[to] = x[from];
        }
    }
}

/*
    Build the scaling matrices d, dinv, d05inv, g and ginv from the n by n
    matrices x and s (presumably the Nesterov-Todd scaling of an SDP block,
    judging by the name -- confirm against the caller).

    Steps, as performed below:
      1. Cholesky factors of the upper triangles of x and s (dpotrf 'U').
      2. ginv <- chol(s) * chol(x)^T, then SVD of that product; the singular
         values go to sv and ginv is overwritten by the SVD routine.
      3. d = Diag(sv), dinv = Diag(1/sv), d05inv = Diag(1/sqrt(sv)).
      4. g    <- chol(x)^T * ginv^T * Diag(1/sqrt(sv))
         ginv <- Diag(sqrt(sv)) * ginv * (chol(x)^T)^{-1}

    n: order of all matrices
    ldx: NOTE(review): not referenced in the body; x and s are both read
         with leading dimension n
    x, s: n by n input matrices, only the upper triangles are read
    d, dinv, d05inv, g, ginv: n by n outputs, leading dimension n
    Returns SUCCESS, ERROR_CHOLESKY_X / ERROR_CHOLESKY_S when a Cholesky
    factorization fails, or the code coming from compute_svd / the
    allocation checks.
*/
CLP_INT compute_scalingOpNTSDP(const CLP_INT n, const CLP_INT ldx, const double *x,
    const double *s,
    double *d, double *dinv, double *d05inv, double *g, double *ginv)
{
    INFO_CLP info = SUCCESS;
    double *xChol = NULL, *sv = NULL, *sv05inv=NULL;

    xChol = (double*) CLP_MALLOC(sizeof(double)*n*n);
    CHECKNULL2(xChol);
    // Start from all-zero buffers so that, after copying only the upper
    // triangles, the strictly lower parts are zero.
    zerofill_mat('A', n, n, xChol, n);
    zerofill_mat('A', n, n, ginv, n);
    // sChol = (double*) CLP_MALLOC(sizeof(double)*n*n);
    // CHECKNULL2(sChol);
    // ginv doubles as the working copy of s.
    copy_mat('U', n, n, x, n, xChol, n);
    copy_mat('U', n, n, s, n, ginv, n);

    // Cholesky factor of x (upper), in place in xChol.
    info = compute_chol(n, xChol, n);
    if (info != SUCCESS)
    {   
CLP_PRINTF("ERROR:%s, %d\n", __FILE__, __LINE__);
printmat(n, n, xChol);
        info = ERROR_CHOLESKY_X;
        goto EXCEPTION;
    }
    // Cholesky factor of s (upper), in place in ginv.
    info = compute_chol(n, ginv, n);
    if (info != SUCCESS)
    {   
printmat(n,n,ginv);
CLP_PRINTF("ERROR:%s, %d\n", __FILE__, __LINE__);
printmat(n, n, ginv);
        info = ERROR_CHOLESKY_S;
        goto EXCEPTION;
    }

    char side = 'R';
    char uplo = 'U';
    char trans = 'T';
    char diag = 'N';
    double alpha = 1.0;
    // ginv <- ginv * xChol^T  (right multiplication by the transposed factor)
    F77_NAME(dtrmm)(&side, &uplo, &trans, &diag, &n, &n, &alpha,
        xChol, &n, ginv, &n FCONE FCONE FCONE FCONE);
    // Both buffers are allocated before either is checked; the cleanup at
    // EXCEPTION releases whichever one succeeded.
    sv = (double*) CLP_MALLOC(sizeof(double)*n);
    sv05inv = (double*) CLP_MALLOC(sizeof(double)*n);
    CHECKNULL2(sv);
    CHECKNULL2(sv05inv);
    // Singular values into sv; compute_svd runs dgesvd with jobvt = 'O',
    // so ginv is overwritten in place.
    info = compute_svd(n, ginv, sv);
    CHECKINFO(info);

    zerofill_mat('A', n, n, d, n);
    zerofill_mat('A', n, n, dinv, n);
    zerofill_mat('A', n, n, d05inv, n);
    for (size_t i=0; i<n; ++i)
    {
        d[i+i*n] = sv[i];
        dinv[i+i*n] = 1.0 / sv[i];
        double z = 1.0 / sqrt(sv[i]);
        sv05inv[i] = z;
        d05inv[i+i*n] = z;
        // From here on sv holds the square roots of the singular values.
        sv[i] = sqrt(sv[i]);
    }
    // g <- ginv^T, then scale column j of g by 1/sqrt(singular value j)
    transpose(n, n, ginv, g);
    // copy_mat('A', n, n, ginv, n, gt, n);
    mul_diagMat('R', n, n, g, n, sv05inv);
    side = 'L';
    uplo = 'U';
    trans = 'T';
    diag = 'N';
    alpha = 1.0;
    // g <- xChol^T * g
    F77_NAME(dtrmm)(&side, &uplo, &trans, &diag, &n, &n, &alpha, xChol, &n, g, &n
            FCONE FCONE FCONE FCONE);

    // Scale row i of ginv by sqrt(singular value i), then
    // ginv <- ginv * (xChol^T)^{-1} via a triangular solve from the right.
    mul_diagMat('L', n, n, ginv, n, sv);
    side = 'R';
    F77_NAME(dtrsm)(&side, &uplo, &trans, &diag, &n, &n, &alpha, xChol, &n, ginv, &n
        FCONE FCONE FCONE FCONE);

EXCEPTION:
    CLP_FREE(xChol);
    // CLP_FREE(sChol);
    CLP_FREE(sv);
    CLP_FREE(sv05inv);
    return info;
}

/*
    x <- g * xs * g^T, computed with two dgemm calls.
    n: order of all matrices (leading dimension n)
    xs: n by n input matrix
    g: n by n input matrix
    z: n by n caller-provided workspace, receives g * xs
    x: n by n output matrix
*/
void scalebackPrimalSDP(const CLP_INT n, const double *xs, const double *g,
    double *z, double *x)
{
    const double one = 1.0;
    const double zero = 0.0;
    char op_left = 'N';
    char op_right = 'N';

    // z <- g * xs
    F77_NAME(dgemm)(&op_left, &op_right, &n, &n, &n, &one, g, &n, xs, &n,
        &zero, z, &n FCONE FCONE);

    // x <- z * g^T
    op_right = 'T';
    F77_NAME(dgemm)(&op_left, &op_right, &n, &n, &n, &one, z, &n, g, &n,
        &zero, x, &n FCONE FCONE);
}

/*
    ss <- g^T * s * g, computed with two dgemm calls.
    n: order of all matrices (leading dimension n)
    s: n by n input matrix
    g: n by n input matrix
    z: n by n caller-provided workspace, receives g^T * s
    ss: n by n output matrix
*/
void scaleDualSDP(const CLP_INT n, const double *s, const double *g,  double *z,
    double *ss)
{
    const double one = 1.0;
    const double zero = 0.0;
    char op_left = 'T';
    char op_right = 'N';

    // z <- g^T * s
    F77_NAME(dgemm)(&op_left, &op_right, &n, &n, &n, &one, g, &n, s, &n,
        &zero, z, &n FCONE FCONE);

    // ss <- z * g
    op_left = 'N';
    F77_NAME(dgemm)(&op_left, &op_right, &n, &n, &n, &one, z, &n, g, &n,
        &zero, ss, &n FCONE FCONE);
}

/*
    s <- ginv^T * ss * ginv, computed with two dgemm calls.
    n: order of all matrices (leading dimension n)
    ss: n by n input matrix
    ginv: n by n input matrix
    z: n by n caller-provided workspace, receives ginv^T * ss
    s: n by n output matrix
*/
void scalebackDualSDP(const CLP_INT n, const double *ss, const double *ginv, 
    double *z, double *s)
{
    const double one = 1.0;
    const double zero = 0.0;
    char op_left = 'T';
    char op_right = 'N';

    // z <- ginv^T * ss
    F77_NAME(dgemm)(&op_left, &op_right, &n, &n, &n, &one, ginv, &n, ss, &n,
        &zero, z, &n FCONE FCONE);

    // s <- z * ginv
    op_left = 'N';
    F77_NAME(dgemm)(&op_left, &op_right, &n, &n, &n, &one, z, &n, ginv, &n,
        &zero, s, &n FCONE FCONE);
}

/*
    Quadratic correction term for n by n matrices:
        Q  = 0.5 * dx * ds + 0.5 * ds * dx
        qc = 0.5 * dinv * Q + 0.5 * Q * dinv
    where the products with dinv go through dtrmm with uplo = 'U', so only
    the upper triangle of dinv is referenced.
    n: order of all matrices (leading dimension n)
    dx, ds, dinv: n by n input matrices
    TZ1, TZ2: n by n caller-provided workspaces, overwritten
    qc: n by n output matrix
*/
void compute_quadcorSDP(const CLP_INT n, const double *dx,
    const double *ds, const double *dinv, double *TZ1, double *TZ2, double *qc)
{
    char op_a = 'N';
    char op_b = 'N';
    double half = 0.5;
    double zero = 0.0;
    CLP_INT count = n*n;
    CLP_INT stride_src = 1;
    CLP_INT stride_dst = 1;
    size_t k;

    // TZ1 <- 0.5 * dx * ds,  TZ2 <- 0.5 * ds * dx
    F77_NAME(dgemm)(&op_a, &op_b, &n, &n, &n, &half, dx, &n, ds, &n,
        &zero, TZ1, &n FCONE FCONE);
    F77_NAME(dgemm)(&op_a, &op_b, &n, &n, &n, &half, ds, &n, dx, &n,
        &zero, TZ2, &n FCONE FCONE);

    // qc <- TZ1 + TZ2, then duplicate the sum into both workspaces
    for (k = 0; k < count; ++k)
    {
        qc[k] = TZ1[k] + TZ2[k];
    }
    F77_NAME(dcopy)(&count, qc, &stride_src, TZ1, &stride_dst);
    F77_NAME(dcopy)(&count, qc, &stride_src, TZ2, &stride_dst);

    // TZ1 <- 0.5 * dinv * TZ1 (left),  TZ2 <- 0.5 * TZ2 * dinv (right)
    char side = 'L';
    char uplo = 'U';
    char diag = 'N';
    F77_NAME(dtrmm)(&side, &uplo, &op_a, &diag, &n, &n, &half, dinv, &n, TZ1, &n
        FCONE FCONE FCONE FCONE);
    side = 'R';
    F77_NAME(dtrmm)(&side, &uplo, &op_a, &diag, &n, &n, &half, dinv, &n, TZ2, &n
        FCONE FCONE FCONE FCONE);

    // qc <- TZ1 + TZ2
    for (k = 0; k < count; ++k)
    {
        qc[k] = TZ1[k] + TZ2[k];
    }
}

/*
    x <- x + alpha * dx, over all n*n entries (daxpy with unit strides).
    n: order of the matrices x and dx
    x: n by n matrix, updated in place
    alpha: step length
    dx: n by n direction matrix
*/
void update_varSDP(const CLP_INT n, double *x, double alpha, double *dx)
{
    CLP_INT stride = 1;
    CLP_INT count = n*n;

    F77_NAME(daxpy)(&count, &alpha, dx, &stride, x, &stride);
}

/*
    Determinant of a square matrix from its LU factorization (dgetrf).
    n: order of x
    x: n by n matrix, leading dimension n; overwritten by the LU factors
    val: receives the determinant (left untouched on failure)
    Returns SUCCESS, or ERROR_FEASIBILITY_LU when dgetrf reports a nonzero
    info.
*/
CLP_INT det(const CLP_INT n, double *x, double *val)
{
    INFO_CLP info = SUCCESS;
    CLP_INT lu_info;
    CLP_INT *pivots = NULL;

    pivots = (CLP_INT*) CLP_MALLOC(sizeof(CLP_INT)*n);
    CHECKNULL2(pivots);

    F77_NAME(dgetrf)(&n, &n, x, &n, pivots, &lu_info);
    if (lu_info != 0)
    {
        info = ERROR_FEASIBILITY_LU;
        goto EXCEPTION;
    }

    // Each row interchange (pivot index differing from its own 1-based
    // position) flips the sign of the determinant.
    int odd_swaps = 0;
    for (size_t k = 0; k < n; ++k)
    {
        if (pivots[k] != (k+1))
        {
            odd_swaps = !odd_swaps;
        }
    }

    // Product of the diagonal of the factored matrix, signed by the parity.
    double product = odd_swaps ? -1.0 : 1.0;
    for (size_t k = 0; k < n; ++k)
    {
        product *= x[k+k*n];
    }
    *val = product;

EXCEPTION:
    CLP_FREE(pivots);
    return info;
}


/*
    Determinant of a matrix through its Cholesky factorization (dpotrf 'U'),
    computed on a private copy so that x itself is not modified.
    n: order of x
    x: n by n matrix, leading dimension n
    val: receives the squared product of the factor's diagonal (left
         untouched on failure)
    Returns SUCCESS, or ERROR_CHOL when dpotrf reports a nonzero info.
*/
CLP_INT detS(const CLP_INT n, const double *x, double *val)
{
    INFO_CLP info = SUCCESS;
    CLP_INT chol_info;
    double *factor = NULL;
    char uplo = 'U';

    factor = (double*) CLP_MALLOC(sizeof(double)*n*n);
    CHECKNULL2(factor);
    memcpy(factor, x, sizeof(double)*n*n);

    F77_NAME(dpotrf)(&uplo, &n, factor, &n, &chol_info FCONE);
    if (chol_info != 0)
    {
        info = ERROR_CHOL;
        goto EXCEPTION;
    }

    // Multiply up the diagonal of the factor, then square the product.
    double diag_product = 1.0;
    for (size_t k = 0; k < n; ++k)
    {
        diag_product *= factor[k+k*n];
    }
    *val = diag_product*diag_product;

EXCEPTION:
    CLP_FREE(factor);
    return info;
}

/*
    Print the entries of a vector on one line, each followed by ", ".
    n: length of vector
    z: vector
*/
void printvec(const CLP_INT n, const double *z)
{
    size_t k = 0;
    while (k < n)
    {
        CLP_PRINTF("%f, ", z[k]);
        ++k;
    }
    CLP_PRINTF("\n");
}

/*
    Print a matrix row by row, one text line per row, followed by an
    empty line.
    m: # of rows
    n: # of columns
    Z: m by n matrix in column-major order (leading dimension m)
*/
void printmat(const CLP_INT m, const CLP_INT n, const double *Z)
{
    for (size_t row = 0; row < m; ++row)
    {
        size_t col = 0;
        while (col < n)
        {
            CLP_PRINTF("%f, ", Z[row + col*m]);
            ++col;
        }
        CLP_PRINTF("\n");
    }
    CLP_PRINTF("\n");
}
