#include <iostream>
#include <vector>
#include <math.h>
#include <fstream>
#include <map>
#include "generateARG.h"
#include <stdio.h>
//#include "basicdefs.h"

using namespace std;

// Graph under construction, kept as parallel edge lists: edge i runs from
// headVec[i] (written to the graph file as "source") to tailVec[i] (written
// as "target") and carries the text edgeLabel[i].
vector<unsigned int> headVec,tailVec;
// Ids of the recombination nodes and of the unlabelled internal nodes; each
// list is written with its own node style when the graph is saved.
vector<unsigned int> recNode, intNode; // recombination and internal nodes
//vector<unsigned int> leftBreakPoint;
// Break points of the recombination nodes: bpVec[i] is printed as the label
// of recNode[i] when the graph is saved.
vector<BreakPoint> bpVec;
// NOTE(review): only string literals (" ", "P", "S") are pushed into this
// vector in this file, so `const char *` would be the accurate element type;
// confirm that no other translation unit declares it before changing it.
vector<char *> edgeLabel;
// map2originalSites[c] is the column number, in the data first handed to
// generateARG, of column c of the current (reduced) matrix.
vector<unsigned int> map2originalSites;  // for keeping track of column numbers


// print break point
void BreakPoint::display() const
{
  printf("(%u,%u)\n",left,right);
}


// Writes the description of an internal node to outgraph.
//   iNode    - id written after "node[ id"
//   outgraph - stream receiving the text
// The node gets a single-space label and the "white-ball.gif" image.
void fprintNode(unsigned int iNode, ofstream &outgraph){
  // Header line: id plus a blank label.
  outgraph << "node[ id " << iNode << " label \" \"" << endl;
  //  outgraph << "   graphics[ width 0.0 height 0.0 depth 0.0 ]\n";
  outgraph << "   graphics[ Image[ Type \"File\" Location \"white-ball.gif\"] ]\n"
           << "]\n";
}


// Writes the description of a leaf node to outgraph.
//   iNode    - id written after "node[ id"; the visible label is iNode+1
//   outgraph - stream receiving the text
// The node uses the "red-ball.gif" image and places its label below it.
void fprintLeaf(unsigned int iNode, ofstream &outgraph){
  const unsigned int shownLabel = iNode+1;  // labels are 1-based, ids 0-based
  outgraph << "node[ id " << iNode << " label \"" << shownLabel << "\"" << endl;
  outgraph << "   graphics[ Image[ Type \"File\" Location \"red-ball.gif\"] ]\n"
           << "   vgj[ labelPosition \"below\" ]\n"
           << "]\n";
}


// Writes the description of a recombination node to outgraph.
//   iNode    - id written after "node[ id"
//   bpoint   - break point shown as the label "(left+1,right+1)"
//   outgraph - stream receiving the text
// The node uses the "blue-ball.gif" image.
void fprintRecNode(unsigned int iNode, const BreakPoint bpoint, ofstream &outgraph){
  outgraph << "node[ id " << iNode
           << " label \"(" << bpoint.left+1 << "," << bpoint.right+1 << ")\"" << endl;
  outgraph << "   graphics[ Image[ Type \"File\" Location \"blue-ball.gif\"] ]\n";
  //  outgraph << "   vgj[ labelPosition " << '"' << "below" << '"' << " ]\n";
  outgraph << "]\n";
}

// Writes one labelled edge to outgraph.
//   head     - id written as the edge "source"
//   tail     - id written as the edge "target"
//   label    - text placed between double quotes after "label"
//   outgraph - stream receiving the text
void fprintEdge(unsigned int head, unsigned int tail, const char *label, ofstream &outgraph){
  outgraph << "edge[ label \"" << label << "\"" << endl;
  outgraph << "   source " << head << " target " << tail << endl;
  outgraph << "]\n";
}


// Prints an edge to stdout as "<head+1> --> <tail+1>" followed by a tab.
// Both ids are shifted by one so that the printed numbering starts at 1.
void printEdge(unsigned int head, unsigned int tail){
  const unsigned int shownHead = head+1;
  const unsigned int shownTail = tail+1;
  printf("%u --> %u\t", shownHead, shownTail);
}


// Writes an edge to outgraph as "<head+1> --> <tail+1>" followed by a tab.
// Both ids are shifted by one so that the printed numbering starts at 1.
void humanPrintEdge(unsigned int head, unsigned int tail, ofstream &outgraph){
  const unsigned int shownHead = head+1;
  const unsigned int shownTail = tail+1;
  outgraph << shownHead << " --> " << shownTail << '\t';
}

// Coalesce identical sequences (rows) of inMat and record the coalescences
// in the global graph vectors.
//   inMat    - current data matrix, one sequence per row
//   labelVec - in: labelVec[i] is the vertex label associated to sequence i
//              in inMat; out: labels of the rows that were kept, in order
//   rmap     - in: one entry per row of inMat; out: the entries belonging to
//              the kept rows, in order (the vector is shrunk accordingly)
//   outgraph - not used by this function
// Returns the matrix made of the first occurrence of every distinct row.
// When inMat has at most one row it is returned unchanged and labelVec/rmap
// are left untouched.
BinMatrix coalesceTrack(const BinMatrix &inMat, vector<unsigned int> &labelVec, vector<unsigned short> &rmap, ofstream &outgraph){

  // New vertices are numbered upwards from the largest label in use.
  unsigned int nextLabel=maxOfVec(labelVec);

  if(inMat.rowsize() <= 1) return(inMat);

  const vector<unsigned short> rmapSnapshot=rmap;  // rmap is overwritten below
  vector<unsigned int> pending=table((unsigned int) 0,inMat.rowsize()-1);
  vector<unsigned int> survivors;   // first row of every group of equal rows
  vector<unsigned int> unmatched;   // rows that differ from the current reference

  while(!pending.empty()){
    const unsigned int ref=pending[0];  // reference sequence of this pass
    bool parentCreated=false;
    survivors.push_back(ref);
    unmatched.clear();

    for(unsigned int k=1; k < pending.size(); k++){
      const unsigned int cand=pending[k];
      if(inMat[cand]!=inMat[ref]){
        unmatched.push_back(cand);  // compare it again in a later pass
        continue;
      }
      if(!parentCreated){
        // First match for this reference: create the parent vertex and
        // hang the reference sequence below it.
        ++nextLabel;
        intNode.push_back(nextLabel);
        headVec.push_back(nextLabel); tailVec.push_back(labelVec[ref]);
        edgeLabel.push_back(" ");
        labelVec[ref]=nextLabel;  // the reference is now represented by its parent
        parentCreated=true;
      }
      // Edge between the parent vertex and the matched sequence.
      headVec.push_back(nextLabel); tailVec.push_back(labelVec[cand]);
      edgeLabel.push_back(" ");
    }
    pending=unmatched;
  }

  // Assemble the reduced matrix together with the matching labels and row map.
  BinMatrix reducedMat;
  vector<unsigned int> keptLabels;
  //    reducedMat.mat.resize(survivors.size());
  reducedMat.resize(survivors.size());
  for(unsigned int k=0; k < survivors.size(); k++){
    const unsigned int src=survivors[k];
    keptLabels.push_back(labelVec[src]);
    reducedMat.mat[k]=inMat[src];
    rmap[k]=rmapSnapshot[src];
  }
  rmap.resize(survivors.size());
  labelVec=keptLabels;
  return(reducedMat);
}

// Collapse consecutive identical columns.
// A column is dropped when it equals the column immediately before it in
// inMat, or the conjugate() of that column; the first column is always kept.
// map2originalSites is rewritten so that it keeps describing the columns of
// the returned matrix.  A matrix with fewer than two columns is returned
// unchanged and map2originalSites is left alone.
BinMatrix collapseColTrack(const BinMatrix &inMat){
  if(inMat.colsize() <= 1) return(inMat);

  // Indices of the columns that survive; column 0 always does.
  vector<unsigned int> kept;
  kept.push_back(0);
  for(unsigned int col=1; col < inMat.colsize(); col++){
    if(!(inMat.column(col)!=inMat.column(col-1))) continue;             // same as previous
    if(!(inMat.column(col)!=conjugate(inMat.column(col-1)))) continue;  // conjugate of previous
    kept.push_back(col);
  }

  const unsigned int nKept=static_cast<unsigned int>(kept.size());
  BinMatrix reducedMat;
  reducedMat.resize(inMat.rowsize(),nKept);

  vector<unsigned int> remappedSites;
  for(unsigned int k=0; k < nKept; k++){
    const unsigned int srcCol=kept[k];
    for(unsigned int row=0; row < inMat.rowsize(); row++){
      reducedMat.mat[row][k]=inMat[row][srcCol];
    }
    remappedSites.push_back(map2originalSites[srcCol]);
  }
  map2originalSites=remappedSites;
  return(reducedMat);
}

// Returns a new matrix with the non-informative sites (columns) removed.
// A column is kept when its colnorm() is greater than 1 and smaller than
// rowsize()-1.  map2originalSites is rewritten so that it keeps describing
// the columns of the returned matrix.  A matrix without columns is returned
// unchanged and map2originalSites is left alone.
BinMatrix killNonInfTrack(const BinMatrix &inMat){
  if(!(inMat.colsize() > 0)) return(inMat);

  const unsigned int upperBound=inMat.rowsize()-1;
  vector<unsigned int> keptCols;       // indices of the informative columns
  vector<unsigned int> remappedSites;  // their original column numbers

  for(unsigned int col=0; col < inMat.colsize(); col++){
    const unsigned int norm=inMat.colnorm(col);
    if(norm <= 1 || norm >= upperBound) continue;  // not informative
    keptCols.push_back(col);
    remappedSites.push_back(map2originalSites[col]);
  }
  map2originalSites=remappedSites;

  const unsigned int nKept=static_cast<unsigned int>(keptCols.size());
  BinMatrix tmpMat;
  tmpMat.resize(inMat.rowsize(),nKept);
  for(unsigned int row=0; row < inMat.rowsize(); row++){
    for(unsigned int k=0; k < nKept; k++){
      tmpMat.mat[row][k]=inMat[row][keptCols[k]];
    }
  }
  return(tmpMat);
}


// Coalesce rows, remove non-informative columns and collapse columns
// iteratively, until a full pass changes neither the number of rows nor the
// number of columns.
//   inMat    - matrix to reduce (not modified)
//   labelVec - vertex labels of the rows; updated by coalesceTrack
//   rmap     - row map; updated by coalesceTrack
//   outgraph - passed on to coalesceTrack
// Returns the reduced matrix.
BinMatrix reduceTrack(const BinMatrix & inMat, vector<unsigned int> &labelVec,vector<unsigned short> &rmap, ofstream &outgraph)
{
  BinMatrix current=inMat;
  unsigned int prevRows=0, prevCols=0;

  for(;;){
    const bool rowsStable=(prevRows == current.rowsize());
    const bool colsStable=(prevCols == current.colsize());
    if(rowsStable && colsStable) break;  // the last pass changed nothing

    prevRows=current.rowsize();
    prevCols=current.colsize();

    const BinMatrix coalesced=coalesceTrack(current,labelVec,rmap,outgraph);
    const BinMatrix informative=killNonInfTrack(coalesced);
    current=collapseColTrack(informative);
  }
  return(current);
}

// Here, seqSet is a set of remaining sequences after givenSeq got removed
// computes the minumum number of rec events required to derive a given seq 
// from a set of sequences

void recTrack(const unsigned int row2BRemoved, const BinMatrix & seqSet, vector<unsigned int> &labelVec, ofstream &outGraph)
{ 
  unsigned int position=0, numRec=0, maxMatchLength=0, curMatchLength;
  unsigned int iC,iR,tmpParent;
  vector<unsigned int> parentVec; // records the sequences where givenSeq came from
  unsigned int  firstRecNode=maxOfVec(labelVec)+1;  // first recomb node
  unsigned int  firstPNode, secondPNode;  // first and second parent nodes
  BreakPoint tmpBP;

  // scan along the sequence length and find maximal matches
  while(position < seqSet.colsize()){
    for(iR=0; iR < row2BRemoved; iR++){
      curMatchLength=0;
      for(iC=position; iC < seqSet.colsize(); iC++){
	if(seqSet[iR][iC]==seqSet[row2BRemoved][iC]){
	  curMatchLength++;
	}
	else{
	  break;
	}
      }
      if(curMatchLength > maxMatchLength){
	maxMatchLength=curMatchLength;
	tmpParent=iR;
      }
    }
    for(iR=row2BRemoved+1; iR < seqSet.rowsize(); iR++){
      curMatchLength=0;
      for(iC=position; iC < seqSet.colsize(); iC++){
	if(seqSet[iR][iC]==seqSet[row2BRemoved][iC]){
	  curMatchLength++;
	}
	else{
	  break;
	}
      }
      if(curMatchLength > maxMatchLength){
	maxMatchLength=curMatchLength;
	tmpParent=iR;
      }

    }
    parentVec.push_back(tmpParent);
    position += maxMatchLength;
    if(position < seqSet.colsize()){
      numRec++; // need more recombination
      //      cout << "\nposition = " << position << endl;
      tmpBP.left=map2originalSites[position-1];
      tmpBP.right=map2originalSites[position];
      //      cout << map2originalSites[1];
      //      tmpBP.display();
      bpVec.push_back(tmpBP);
    }
    maxMatchLength=0;
  }

  //  cout << "numRec = " << numRec << endl;

  //  cout << "parentVec: "; printVec(parentVec);
  firstPNode=firstRecNode+numRec;
  secondPNode=firstRecNode+numRec+1;  
  //Note:  pareVec.size() = numRec+1

  /*

   |    |   |   |        |
   |    |   |   |        |
   p1   p2  p3  p4       p*n+1  
   |\  /|  /|  /|       /|
   | \/_|_/_|_/_|_ ... / | 
   | v1 |v2 |v3 |      | |
   |    |   |   |      | |
   s1   s2  s3  s4     | s_n
                       |
                 recombined seq
  */

  recNode.push_back(firstRecNode);
  intNode.push_back(firstPNode);
  intNode.push_back(secondPNode);


  headVec.push_back(firstPNode); tailVec.push_back(firstRecNode);
  edgeLabel.push_back("P");

  headVec.push_back(secondPNode); tailVec.push_back(firstRecNode);
  edgeLabel.push_back("S");

  headVec.push_back(firstPNode); tailVec.push_back(labelVec[parentVec[0]]);
  edgeLabel.push_back(" ");

  headVec.push_back(secondPNode); tailVec.push_back(labelVec[parentVec[1]]);
  edgeLabel.push_back(" ");

  labelVec[parentVec[0]]=firstPNode; // redefining label
  labelVec[parentVec[1]]=secondPNode; // redefining label

  for(iR=1; iR < numRec; iR++){
    recNode.push_back(iR+firstRecNode);
    intNode.push_back(iR+secondPNode);

    headVec.push_back(iR-1+firstRecNode); tailVec.push_back(iR+firstRecNode);
    edgeLabel.push_back("P");
    headVec.push_back(iR+secondPNode); tailVec.push_back(iR+firstRecNode);
    edgeLabel.push_back("S");
    headVec.push_back(iR+secondPNode); tailVec.push_back(labelVec[parentVec[iR+1]]);
    edgeLabel.push_back(" ");

    labelVec[parentVec[iR+1]]=iR+secondPNode; // redefining label
  }

  headVec.push_back(numRec-1+firstRecNode); tailVec.push_back(labelVec[row2BRemoved]);
  edgeLabel.push_back(" ");
}


// Modifies row numbers after a row has been removed: every entry located
// after position removedRow is overwritten with its own index minus one.
// Entries up to and including removedRow are left untouched.
//   removedRow - position of the removed row
//   original   - row-number table, updated in place
// The loop index has the vector's own size type so that it can always reach
// original.size(); stored values are truncated to unsigned short, the
// element type of the table.
void modifiedRowNumber(const unsigned short removedRow, vector<unsigned short> & original){
  typedef vector<unsigned short>::size_type Index;
  for(Index iR=static_cast<Index>(removedRow)+1; iR < original.size(); iR++){
    original[iR]=static_cast<unsigned short>(iR-1);
  }
}
   
// Returns the index of the first entry of vec equal to originalSeq.
// Returns 0 when there is no such entry, so a result of 0 is ambiguous and
// callers that need to tell the two cases apart must check vec[0].
// The loop index has the vector's own size type so that an unsuccessful
// search always terminates, whatever the length of vec.
unsigned short findPosition(unsigned short originalSeq, const vector<unsigned short> &vec){
  typedef vector<unsigned short>::size_type Index;
  for(Index iI=0; iI < vec.size(); iI++){
    if(vec[iI]==originalSeq) return static_cast<unsigned short>(iI);
  }
  return 0;
}

// Reconstructs a graph for inData and saves it into the file gfile.
//
//   inData     - data matrix, one sequence per row
//   removedSeq - optimal sequence of removed rows, labelled according to the
//                initially reduced data, not the original data; it is
//                converted here to row numbers of the successively modified
//                data
//   gfile      - name of the output file
//
// The data are reduced first; then, while more than one row is left, the
// next row of removedSeq is explained by recombination (recTrack), removed,
// and the data are reduced again.  The file receives the edge list as
// comment lines starting with "# ", followed by a "graph[ directed 1 ... ]"
// block with the leaves, internal nodes, recombination nodes and edges.
//
// The global graph vectors are emptied at the start, so every call describes
// only its own graph.  If gfile cannot be opened, a diagnostic is written to
// cerr and nothing else is done.  If removedSeq runs out before the data are
// down to one row, a diagnostic is written and the graph built so far is
// saved.
void generateARG(const BinMatrix &inData, const vector<unsigned short> &removedSeq, char *gfile){

  ofstream outGraph(gfile, ios::out);
  if(!outGraph){
    cerr << "generateARG: cannot open output file '" << gfile << "'." << endl;
    return;
  }

  // Vertex ids restart at 0 on every call, so leftovers of a previous call
  // would collide with the new graph.
  headVec.clear();
  tailVec.clear();
  edgeLabel.clear();
  recNode.clear();
  intNode.clear();
  bpVec.clear();

  BinMatrix copyInData;
  copyInData=inData;
  vector<unsigned short> rowMap=table( 0,(unsigned short)( inData.rowsize()-1));
  vector<unsigned int> labelVec=table((unsigned int) 0,copyInData.rowsize()-1);

  // Before any reduction every column maps to itself.
  map2originalSites.resize(inData.colsize());
  for(unsigned int iC=0; iC < inData.colsize(); iC++){
    map2originalSites[iC]=iC;
  }

  cout << "\nReconstructing an ARG ... ";
  
  // reduce the data first
  copyInData=reduceTrack(copyInData,labelVec,rowMap,outGraph);

  // need to redefine row numbers
  rowMap=table( 0,(unsigned short)( copyInData.rowsize()-1));

  unsigned int iRemove=0;
  while(copyInData.rowsize() > 1){
    if(iRemove >= removedSeq.size()){
      cerr << "generateARG: removal order exhausted after " << iRemove
           << " sequence(s) while " << copyInData.rowsize()
           << " rows remain; saving the partial graph." << endl;
      break;
    }
    // NOTE(review): findPosition also returns 0 when the row is not found.
    const unsigned short tmpPos=findPosition(removedSeq[iRemove],rowMap);
    recTrack(tmpPos,copyInData,labelVec,outGraph);
    removeEntry(tmpPos,labelVec);
    copyInData=removeRow(tmpPos,copyInData,rowMap); //rowMap gets modified here
    iRemove++;
    copyInData=reduceTrack(copyInData,labelVec,rowMap,outGraph); //rowMap gets modified here
  }

  cout << "Done." << endl;

  // saving the graph into a file

  // first print edges in human readable form, five per comment line
  outGraph << "# Edges of the graph are as follows: \n";
  for(unsigned int iR=0; iR < headVec.size(); iR++){
    if(iR%5==0) outGraph << "# ";
    humanPrintEdge(headVec[iR],tailVec[iR],outGraph);
    if((iR+1)%5==0) outGraph << endl;
  }

  outGraph << "\n\ngraph[ directed 1\n"; 
  // print leaves: one per row of the input data
  for(unsigned int iR=0; iR < inData.rowsize(); iR++){
    fprintLeaf(iR,outGraph);
  }
  // print internal nodes
  for(unsigned int iR=0; iR < intNode.size(); iR++){
    fprintNode(intNode[iR],outGraph);
  }
  // print recombination nodes, each labelled with its break point
  for(unsigned int iR=0; iR < recNode.size(); iR++){
    fprintRecNode(recNode[iR],bpVec[iR],outGraph);
  }
  // print edges
  for(unsigned int iR=0; iR < headVec.size(); iR++){
    fprintEdge(headVec[iR],tailVec[iR],edgeLabel[iR],outGraph);
  }

  outGraph << "]\n";

  outGraph.close();
}
