/****************************************************************************************/
/*                                                                                      */
/* DOT DIFfusion                                                                        */
/*      This function performs the dot diffusion algorithm, as first introduced by      */
/*      Donald Knuth, given a class matrix.                                             */
/*                                                                                      */
/* Synopsis:                                                                            */
/*      Y=dotdif(X,M)                                                                   */
/*              X=Continuous tone original of type double.                              */
/*              M=Class matrix defining the order in which pixels are processed within  */
/*                a halftone cell.                                                      */
/*                                                                                      */
/*              If X is a color image, the class matrix is tiled along the color axis.  */
/*                                                                                      */
/* Dr. Daniel Leo Lau                                                                   */
/* Copyright June 9, 1998                                                               */
/*                                                                                      */
/****************************************************************************************/ 
#include <math.h>
#include "mex.h"
/* Threshold applied to (input + diffused error): values at or above it map to 1. */
enum { Threshold = 0 };

/* Size of the image X being halftoned (rows, columns, channels).
   Set in mexFunction() and read by the halftoning routines. */
int image_row, image_col, image_chn;

/* Size of the class matrix M (rows, columns, planes). */
int mask_row, mask_col, mask_chn;

/* One class-matrix entry: its (row, col) position inside the cell and its class
   value, stored in `cost` and used by hpsort_cost_array() as the sort key.
   dot_diffusion() also reuses this type for neighbour bookkeeping, where row/col
   are image coordinates and cost is a 0/1 "may receive error" flag. */
typedef struct cost_array_node_s {
        int    row;   /* row index                */
        int    col;   /* column index             */
        double cost;  /* class value / sort key   */
} cost_array_node;

/********************************************************************************/
/*                                                                              */
/* HPSORT_COST_ARRAY                                                            */
/*      Sorts ra[0..n-1] in place into ascending order of the `cost` field      */
/*      using heapsort. The sort is not stable: entries with equal cost may     */
/*      end up in any (deterministic) relative order.                           */
/*                                                                              */
/********************************************************************************/

/* Places `item` into the max-heap heap[0..last-1], starting at the 1-based
   position `hole` and moving larger children up until no child is larger. */
static void hpsort_sift_down(cost_array_node *heap, cost_array_node item,
                             int hole, int last)
{
    int child;

    for (child = hole * 2; child <= last; child *= 2) {
        /* Choose the larger of the two children (1-based child, child+1). */
        if (child < last && heap[child - 1].cost < heap[child].cost)
            child++;
        if (!(item.cost < heap[child - 1].cost))
            break;
        heap[hole - 1] = heap[child - 1];
        hole = child;
    }
    heap[hole - 1] = item;
}

void hpsort_cost_array(cost_array_node *ra, int n)
{
    int node, last;
    cost_array_node displaced;

    if (n < 2)
        return;

    /* Phase 1: heapify bottom-up so ra[0] holds the largest cost. */
    for (node = n / 2; node >= 1; node--)
        hpsort_sift_down(ra, ra[node - 1], node, n);

    /* Phase 2: swap the maximum to the end and shrink the heap by one. */
    last = n;
    for (;;) {
        displaced = ra[last - 1];
        ra[last - 1] = ra[0];
        last--;
        if (last == 1) {
            ra[0] = displaced;
            break;
        }
        hpsort_sift_down(ra, displaced, 1, last);
    }
}

/********************************************************************************/
/*                                                                              */
/* GET_NEIGHBOR_COST                                                            */
/*      Returns 1.0 when (Arow,Acol) lies inside the image_row x image_col      */
/*      image and (Brow,Bcol) lies inside the mask_row x mask_col class         */
/*      matrix; otherwise returns 0.0.                                          */
/*                                                                              */
/********************************************************************************/
double get_neighbor_cost(int Arow, int Acol, int Brow, int Bcol)
{
    int a_in_image = (Arow >= 0) && (Arow < image_row) &&
                     (Acol >= 0) && (Acol < image_col);
    int b_in_mask  = (Brow >= 0) && (Brow < mask_row) &&
                     (Bcol >= 0) && (Bcol < mask_col);

    return (a_in_image && b_in_mask) ? 1.0 : 0.0;
}

/********************************************************************************/
/*                                                                              */
/* DOT_DIFFUSION                                                                */
/*      Halftones one image_row x image_col channel (column-major) using the   */
/*      class matrix mask_image (mask_row x mask_col), tiled over the image.   */
/*      Pixels are visited in ascending class value; each pixel is set to 1    */
/*      when input + accumulated error >= Threshold, else 0. Its error is     */
/*      then split equally among the 8-connected neighbours that lie inside   */
/*      the image and have not been processed yet.                              */
/*                                                                              */
/*      output_image  receives 1.0 / 0.0 per pixel.                             */
/*      error_image   accumulation buffer; expected to be zero on entry. On    */
/*                    return each entry is output - (input + received error).  */
/*      input_image   continuous-tone channel.                                  */
/*      mask_image    class matrix for this channel.                            */
/*                                                                              */
/*      NOTE(review): all valid neighbours get equal weight; Knuth's original  */
/*      scheme weights orthogonal neighbours more heavily - confirm intended.  */
/*                                                                              */
/********************************************************************************/
void dot_diffusion(double *output_image, double *error_image,
                   double *input_image, double *mask_image)
{
    /* 8-connected neighbourhood offsets {drow, dcol}, in the original order. */
    static const int neighbor_offset[8][2] = {
        {-1, -1}, {-1,  0}, {-1, +1},
        { 0, -1},           { 0, +1},
        {+1, -1}, {+1,  0}, {+1, +1}
    };
    int m, n, r, s, row, col, mask_size;
    int index1, index2, num_valid_neighbors;
    cost_array_node *mask, neighbor_list[8];
    unsigned char *process_image;

    /* An empty image or class matrix leaves nothing to do (and no valid tiling). */
    if (image_row <= 0 || image_col <= 0 || mask_row <= 0 || mask_col <= 0) return;
    mask_size = mask_row*mask_col;

    /* process_image[i] == 1 once pixel i has been quantised. */
    process_image=(unsigned char*)mxCalloc(image_row*image_col, sizeof(unsigned char));

    /* Build the list of class-matrix positions sorted by class value. */
    mask=(cost_array_node*)mxCalloc(mask_size, sizeof(cost_array_node));
    for (m=0; m < mask_row; m++){
        for (n=0; n < mask_col; n++){
            index1=m+n*mask_row;
            mask[index1].row=m;
            mask[index1].col=n;
            mask[index1].cost=mask_image[index1];
        }
    }
    hpsort_cost_array(mask, mask_size);

    for (r=0; r<mask_size; r++){
        /* Visit every tiled copy of the r-th class position, row by row. */
        for (row=mask[r].row; row<image_row; row+=mask_row){
            for (col=mask[r].col; col<image_col; col+=mask_col){
                index1=row+col*image_row;

                /* Quantise pixel + accumulated error; store output - value as error. */
                error_image[index1]+=input_image[index1];
                output_image[index1]=error_image[index1]>=Threshold;
                error_image[index1]=output_image[index1]-error_image[index1];
                process_image[index1]=(unsigned char)1;

                /* A neighbour may receive error only if it is inside the image and
                   not yet processed. The bounds check (get_neighbor_cost) runs
                   first so process_image is never indexed outside its buffer. */
                num_valid_neighbors=0;
                for (s=0; s<8; s++){
                    neighbor_list[s].row=row+neighbor_offset[s][0];
                    neighbor_list[s].col=col+neighbor_offset[s][1];
                    neighbor_list[s].cost=get_neighbor_cost(neighbor_list[s].row,neighbor_list[s].col,mask[r].row,mask[r].col);
                    if (neighbor_list[s].cost>0.5 &&
                        process_image[neighbor_list[s].row+neighbor_list[s].col*image_row]==(unsigned char)1)
                        neighbor_list[s].cost=0;
                    num_valid_neighbors+=(int)(neighbor_list[s].cost>0.5);
                }

                /* Split the error equally; it is stored as output - value, hence
                   subtracted. Error with no valid neighbour is discarded. */
                if (num_valid_neighbors>0){
                    for (s=0; s<8; s++){
                        if (neighbor_list[s].cost==1){
                            index2=neighbor_list[s].row+neighbor_list[s].col*image_row;
                            error_image[index2]-=error_image[index1]/(double)num_valid_neighbors;
                        }
                    }
                }
            }
        }
    }

    mxFree(mask);
    mxFree(process_image);
    return;
}

/********************************************************************************/
/*                                                                              */
/* MEXFUNCTION                                                                  */
/*      MATLAB gateway for [Y,E]=dotdif(X,M).                                   */
/*      prhs[0]  X: real, non-sparse double image, 2-D or 3-D.                  */
/*      prhs[1]  M: real, non-sparse double class matrix, 2-D or 3-D. Channel  */
/*               m (0-based) of X uses plane (m mod size(M,3)) of M.           */
/*      plhs[0]  Y: logical halftone with the same size as X.                   */
/*      plhs[1]  E: optional double array, same size as X, holding the error   */
/*               buffer left by dot_diffusion().                                */
/*                                                                              */
/********************************************************************************/
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    mxLogical *output_data;
    double *output_image, *input_image, *error_image, *mask_image;
    int m, number_of_dims, plane_size, mask_plane_size, total_size;
    /* The MX API describes dimensions as mwSize (size_t in 64-bit builds);
       reading them through int pointers yields wrong sizes there. */
    mwSize output_dims[3];
    const mwSize *dim_array;

    /****** Check for errors in user's call to function. ***********************/
    if (nrhs!=2)
        mexErrMsgTxt("DOT DIFfusion requires exactly two input argument!");
    else if (nlhs>2)
        mexErrMsgTxt("DOT DIFfusion returns one or two output argument!");
    else if (!mxIsNumeric(prhs[0]) ||
              mxIsComplex(prhs[0]) ||
              mxIsSparse(prhs[0]) ||
             !mxIsDouble(prhs[0]))
        mexErrMsgTxt("Input X must be a real matrix of type double!");
    else if (!mxIsNumeric(prhs[1]) ||
              mxIsComplex(prhs[1]) ||
              mxIsSparse(prhs[1]) ||
             !mxIsDouble(prhs[1]))
        mexErrMsgTxt("Input MASK must be a real matrix of type double!");

    /****** Get row, column and channel counts for input image *****************/
    number_of_dims=(int)mxGetNumberOfDimensions(prhs[0]);
    if (number_of_dims<2 || number_of_dims>3)
        mexErrMsgTxt("Input image X must be 2 or 3 dimensional!");
    dim_array=mxGetDimensions(prhs[0]);
    image_row=(int)dim_array[0];
    image_col=(int)dim_array[1];
    image_chn=(number_of_dims==3) ? (int)dim_array[2] : 1;
    input_image=mxGetPr(prhs[0]);

    /****** Get row, column and channel counts for mask ************************/
    number_of_dims=(int)mxGetNumberOfDimensions(prhs[1]);
    if (number_of_dims<2 || number_of_dims>3)
        mexErrMsgTxt("Input image MASK must be 2 or 3 dimensional!");
    dim_array=mxGetDimensions(prhs[1]);
    mask_row=(int)dim_array[0];
    mask_col=(int)dim_array[1];
    mask_chn=(number_of_dims==3) ? (int)dim_array[2] : 1;
    mask_image=mxGetPr(prhs[1]);

    /* An empty class matrix defines no processing order, and mask_chn==0 would
       make the plane selection below divide by zero. */
    if (mask_row==0 || mask_col==0 || mask_chn==0)
        mexErrMsgTxt("Input MASK must not be empty!");

    plane_size=image_row*image_col;
    mask_plane_size=mask_row*mask_col;
    total_size=plane_size*image_chn;

    /****** Create output variables and extract pointers to data. **************/
    output_dims[0]=(mwSize)image_row;
    output_dims[1]=(mwSize)image_col;
    output_dims[2]=(mwSize)image_chn;
    plhs[0]=mxCreateLogicalArray((image_chn==1) ? 2 : 3, output_dims);
    output_data=mxGetLogicals(plhs[0]);
    output_image=(double*)mxCalloc(total_size, sizeof(double));

    /* The error buffer becomes E when the caller asks for it; otherwise it is
       scratch memory released below. Both allocations start zero-filled. */
    if (nlhs==2){
        plhs[1]=mxCreateNumericArray((image_chn==1) ? 2 : 3, output_dims, mxDOUBLE_CLASS, mxREAL);
        error_image=mxGetPr(plhs[1]);
    }
    else error_image=(double*)mxCalloc(total_size, sizeof(double));

    /* Halftone each channel independently, cycling through the mask planes. */
    for (m=0; m<image_chn; m++){
        dot_diffusion(&output_image[m*plane_size], &error_image[m*plane_size],
                      &input_image[m*plane_size], &mask_image[(m%mask_chn)*mask_plane_size]);
    }

    for (m=0; m<total_size; m++) output_data[m]=(mxLogical)(output_image[m]!=0.0);

    /* Release the scratch error buffer only when it was not returned as E.
       (The former test, nrhs==1, could never be true and leaked it.) */
    if (nlhs!=2) mxFree(error_image);
    mxFree(output_image);
    return;
}

