/*
    This file is part of nncore.
    
    This code is written by Stefano Merler, <merler@fbk.it>.
    (C) 2008 Fondazione Bruno Kessler - Via Santa Croce 77, 38100 Trento, ITALY.

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/


#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "nn.h"

int compute_nn(NearestNeighbor *nn,int n,int d,double *x[],int y[],
	       int k, int dist)
     /*
       Compute (store) a k-nearest-neighbour model.

       x,y,n,d are the training data: n points of dimension d, x[i]
       being the i-th point and y[i] its class label. The arrays are
       NOT copied: nn keeps the x and y pointers, so they must remain
       valid for as long as the model is used.
       k is the number of nearest neighbours; it must satisfy 1 <= k <= n.
       dist is the adopted distance (DIST_SQUARED_EUCLIDEAN or
       DIST_EUCLIDEAN).

       Labels must be -1,1 for binary classification and 1,...,nclasses
       in the multiclass case. On success nn->classes holds the distinct
       labels returned by iunique() (the checks below expect them in
       ascending order).

       Return value: 0 on success, 1 otherwise. If the labels are
       rejected after iunique() succeeded, nn->classes is released and
       set to NULL.
     */
{
  int i;

  /* k < 1 would make predict_nn divide by zero; it also covers n < 1,
     for which iunique() would read past the end of y. */
  if(k<1){
    fprintf(stderr,"compute_nn: k must be positive\n");
    return 1;
  }
  if(k>n){
    fprintf(stderr,"compute_nn: k must not be greater than n\n");
    return 1;
  }
  
  switch(dist){
  case DIST_SQUARED_EUCLIDEAN:
  case DIST_EUCLIDEAN:
    break;
  default:
    fprintf(stderr,"compute_nn: distance not recognized\n");
    return 1;
  }

  nn->n=n;
  nn->d=d;
  nn->k=k;
  nn->dist=dist;

  nn->nclasses=iunique(y,n, &(nn->classes));

  if(nn->nclasses<=0){
    /* iunique() failed: the state of nn->classes is not known here,
       so it is deliberately left untouched. */
    fprintf(stderr,"compute_nn: iunique error\n");
    return 1;
  }
  if(nn->nclasses==1){
    fprintf(stderr,"compute_nn: only 1 class recognized\n");
    goto fail;
  }

  if(nn->nclasses==2){
    if(nn->classes[0] != -1 || nn->classes[1] != 1){
      fprintf(stderr,"compute_nn: for binary classification classes must be -1,1\n");
      goto fail;
    }
  }
  else{
    /* nclasses > 2 here */
    for(i=0;i<nn->nclasses;i++)
      if(nn->classes[i] != i+1){
	fprintf(stderr,"compute_nn: for %d-class classification classes must be 1,...,%d\n",nn->nclasses,nn->nclasses);
	goto fail;
      }
  }

  nn->x=x;
  nn->y=y;

  return 0;

 fail:
  /* The label array was allocated by iunique(); do not leak it when the
     labels are rejected. NOTE(review): assumes iunique() allocates with
     an ivector()-compatible allocator, as the rest of this file does. */
  free_ivector(nn->classes);
  nn->classes=NULL;
  return 1;
}


int predict_nn(NearestNeighbor *nn, double x[],double **margin)
     /*
       Predict the nn model on a test point x.

       *margin is allocated here as an array of length nn->nclasses.
       (*margin)[j] receives the proportion of the k nearest training
       points whose label is nn->classes[j]. On success the caller owns
       *margin and must release it with free_dvector(). On error
       *margin is released here and set to NULL.

       Return value: the predicted value on success (-1 or 1 for
       binary classification; 1,...,nclasses in the multiclass case),
       0 on success with non-unique classification (a tie for the
       largest proportion), -2 otherwise.
     */
{
  int i,j;
  double *dist=NULL;
  int *indx=NULL;
  int *knn_pred=NULL;
  double one_k;
  double pred_n;
  int pred_class;
  int ret=-2;

  *margin=NULL;

  /* Defensive: a model not validated by compute_nn could otherwise
     divide by zero or read past the sorted index array. */
  if(nn->k<1 || nn->k>nn->n){
    fprintf(stderr,"predict_nn: k must be in 1..n\n");
    return -2;
  }

  /* Validate the distance before allocating anything. */
  if(nn->dist!=DIST_SQUARED_EUCLIDEAN && nn->dist!=DIST_EUCLIDEAN){
    fprintf(stderr,"predict_nn: distance not recognized\n");
    return -2;
  }

  if(!((*margin)=dvector(nn->nclasses)) ||
     !(dist=dvector(nn->n)) ||
     !(indx=ivector(nn->n)) ||
     !(knn_pred=ivector(nn->k))){
    fprintf(stderr,"predict_nn: out of memory\n");
    goto cleanup;
  }

  /* Both supported metrics rank neighbours identically (sqrt is
     monotonic), so the squared distance is sufficient for either. */
  for(i=0;i<nn->n;i++)
    dist[i]=euclidean_squared_distance(x,nn->x[i],nn->d);

  /* Sort distances ascending, carrying along the training indices. */
  for(i=0;i<nn->n;i++)
    indx[i]=i;
  dsort(dist,indx,nn->n,SORT_ASCENDING);

  for(i=0;i<nn->k;i++)
    knn_pred[i]=nn->y[indx[i]];

  /* Do not rely on dvector() returning zeroed memory. */
  for(j=0;j<nn->nclasses;j++)
    (*margin)[j]=0.0;

  one_k=1.0/nn->k;
  for(i=0;i<nn->k;i++)
    for(j=0;j<nn->nclasses;j++)
      if(knn_pred[i] == nn->classes[j]){
	(*margin)[j] += one_k;
	break;
      }

  /* Majority vote: first class with the largest proportion. */
  pred_class=nn->classes[0];
  pred_n=(*margin)[0];
  for(j=1;j<nn->nclasses;j++)
    if((*margin)[j]> pred_n){
      pred_class=nn->classes[j];
      pred_n=(*margin)[j];
    }

  /* Proportions are multiples of 1/k, so any difference smaller than
     a tenth of 1/k means the counts are equal: report a tie as 0. */
  for(j=0;j<nn->nclasses;j++)
    if(nn->classes[j] != pred_class)
      if(fabs((*margin)[j]-pred_n) < one_k/10.0){
	pred_class = 0;
	break;
      }

  /* Valid labels are never -2 (compute_nn enforces -1,1 or 1..nclasses),
     so ret distinguishes success from error in the cleanup below. */
  ret=pred_class;

 cleanup:
  /* Guards kept because the NULL behaviour of the project's free_*
     wrappers is not visible here. */
  if(knn_pred)
    free_ivector(knn_pred);
  if(indx)
    free_ivector(indx);
  if(dist)
    free_dvector(dist);
  if(ret==-2 && *margin){
    free_dvector(*margin);
    *margin=NULL;
  }

  return ret;
}
