/*----------------------------MegaWave2 module----------------------------*/
/* mwcommand
name = {knn_image};
version = {"1.0"};
author = {"Lionel Moisan"};
function = {"Image classification using k-nearest neighbor"};
usage = {
 'k':[k=1]->k       "number of neighbors (k-NN), default 1",
 train->train       "training set (Cmovie)",
 labels->labels     "labels for training set (integers stored in a Fsignal)",
 test->test         "test set (Cmovie)",
 classif<-knn_image "output estimated labels for test set (Fsignal)",
 {
   true->true    "true labels (Fsignal) for test set (performance evaluation)",
   error<-error  "output: error made on test set"
 }
        };
*/
 
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "mw.h"


/* Squared Euclidean distance between the gray levels of images u and v.
   The pixel count is taken from u (nrow*ncol); v is read at the same
   indices, so the caller must make sure v has at least as many pixels
   (this is not checked here). Pixels are accumulated from the last one
   down to the first. */
double imdist(u,v)
     Cimage u,v;
{
  double sum = 0.;
  int p = u->nrow*u->ncol;

  while (p--) {
    double diff = (double)u->gray[p] - (double)v->gray[p];
    sum += diff*diff;
  }

  return sum;
}

/* Insert the pair (distance d, class c) into a k-slot neighbor table.
   dtab[] holds distances and ctab[] the matching classes, kept sorted by
   increasing distance. A slot whose class is -1 is empty; empty slots are
   always located after the filled ones. The new pair is placed after any
   filled entry whose distance is not greater than d, so equal distances keep
   their arrival order. The entry in the last slot is dropped to make room.
   If the table is full and d is not smaller than the last stored distance,
   the table is left unchanged. A genuine class -1 cannot be stored, because
   it is indistinguishable from an empty slot. */
void insert(dtab,ctab,d,c,k)
     double *dtab,d;
     int *ctab,c,k;
{
  int pos,s;

  /* walk back over empty slots and over entries strictly farther than d */
  pos = k;
  while (pos > 0) {
    if (ctab[pos-1] != -1 && !(dtab[pos-1] > d)) break;
    pos--;
  }
  if (pos >= k) return;    /* d is not among the k smallest distances */

  /* open a hole at pos by shifting the tail one slot to the right */
  for (s = k-1; s > pos; s--) {
    ctab[s] = ctab[s-1];
    dtab[s] = dtab[s-1];
  }
  ctab[pos] = c;
  dtab[pos] = d;
}

/*------------------------------ MAIN MODULE ------------------------------*/

Fsignal knn_image(train,labels,test,true,error,k)
     Cmovie train,test;
     Fsignal labels,true;
     float *error;
     int *k;
{
  Fsignal classif;
  Cimage u,v;
  int i,j,l,nmax,lmax,*class;
  double d,*dist;

  /* size of test set */
  for (u=test->first,j=0;u;u=u->next,j++);
  classif = mw_change_fsignal(NULL,j);
  if (true) *error=0.;
  dist = (double *)malloc(*k*sizeof(double));
  class = (int *)malloc(*k*sizeof(int));

  /* loop on test images */
  for (u=test->first,j=0;u;u=u->next,j++) {

    for (i=0;i<*k;i++) class[i]=-1;

    /* loop on train images */
    for (v=train->first,i=0;v;v=v->next,i++) {
      d = imdist(u,v);
      insert(dist,class,d,(int)labels->values[i],*k);
    }
    /* find dominant label */
    for (i=0;i<*k;i++) dist[i]=1.;
    for (i=*k-1;i>=0;i--) 
      for (l=0;l<i;l++)
	if (class[l]==class[i]) dist[l]+=1.;
    lmax = class[0];
    nmax = (int)dist[0];
    for (i=1;i<*k;i++)
      if ((int)dist[i]>nmax) { nmax=(int)dist[i]; lmax=class[i];}

    /* store result */
    classif->values[j] = (float)lmax;
    if (true) *error += (classif->values[j]!=true->values[j]?1.:0.);

  }

  if (true) *error/=(float)j;  
  free(class); free(dist);

  return(classif);
}
