#include <Scatter.h>
#include <iostream.h>

// Build a scatter experiment from the compile-time DEFAULT_* settings.
// No particle sets exist yet: currPSet/oldPSets stay NULL until makeScene()
// creates them.  The deflector is created last because makeDeflector()
// reads catchRadius, deflectorRadius and deflectorMass.
Scatter::Scatter () {
  // Particle-set containers are only created by makeScene().
  currPSet        = NULL;
  oldPSets        = NULL;

  // Physical and simulation parameters.
  deflectorMass   = DEFAULT_DM;
  deflectorRadius = DEFAULT_DR;
  catchRadius     = DEFAULT_CR;
  timerRate       = DEFAULT_TR;

  // Scene root (initially empty) and the first deflector particle.
  root            = new SoSeparator();
  deflector       = makeDeflector();
}

void Scatter::clear () {
  root->removeAllChildren();
  pset.nodeCreated = false;
}

void Scatter::deleteUnscattered () {
  int i, j, numsets, num;
  SoParticle *part;
  SoSeparator *sep, *particles;
  
  if (root->getNumChildren() > 0) {
    particles = (SoSeparator *)currPSet->getChild(2);
    num = particles->getNumChildren();
    for (i = 0; i < num; i++) {
      part = (SoParticle *)particles->getChild(i);
      if (part->getLocation().length() < catchRadius) {
	particles->removeChild(i);
	num--;
	i--;
      }
    }

    numsets = oldPSets->getNumChildren();
    for (j = 0; j < numsets; j++) {
      sep = (SoSeparator *)oldPSets->getChild(j);
      particles = (SoSeparator *)sep->getChild(2);
      num = particles->getNumChildren();
      for (i = 0; i < num; i++) {
	part = (SoParticle *)particles->getChild(i);
	if (part->getLocation().length() < catchRadius) {
	  particles->removeChild(i);
	  num--;
	  i--;
	}
      }
    }
  }    
}

// (Re)build the scene graph under root and return root.
// Child layout, which newParticleSet(), update() and deleteUnscattered()
// rely on:
//   0: SoComplexity with value 1.0
//   1: the deflector particle
//   2: the current particle set (from pset.getNewSet())
//   3: oldPSets, a fresh separator for archived sets
SoSeparator* Scatter::makeScene () {
  // When the scene already exists, clear() releases root's reference to the
  // deflector.  If that was its only reference, Inventor would delete the
  // node before it is re-added below.  Hold a temporary reference across
  // the rebuild.
  deflector->ref();

  clear();

  SoComplexity *comp = new SoComplexity();
  comp->value.setValue(1.0);
  root->addChild(comp);

  root->addChild(deflector);
  // Drop the temporary reference without deleting; root now owns one.
  deflector->unrefNoDelete();

  currPSet = pset.getNewSet();
  root->addChild(currPSet);

  oldPSets = new SoSeparator();
  root->addChild(oldPSets);

  return root;
}

// Start a new particle set with a freshly generated deflector.
// If the scene has not been built yet, build it.  Otherwise replace the
// deflector (child 1), archive the running set under oldPSets and install a
// new set as child 2.
void Scatter::newParticleSet () {
  deflector = makeDeflector();

  if (root->getNumChildren() == 0) {
    root = makeScene();
    return;
  }

  root->replaceChild(1, deflector);

  // Archive the running set BEFORE root lets go of it.  This way it still
  // has a parent (and a reference) when replaceChild() removes it from root.
  oldPSets->addChild(currPSet);
  currPSet = pset.getNewSet();
  root->replaceChild(2, currPSet);
}

// Create a new deflector particle placed at (catchRadius, 0, 0).
// Its velocity points from that start position toward a random point drawn
// from inside a sphere of radius 0.75 around the origin.  The direction is
// normalised and then divided by 10.  Radius and mass come from
// deflectorRadius and deflectorMass.
SoParticle* Scatter::makeDeflector () {
  SbVec3f start = SbVec3f(catchRadius, 0.0, 0.0);

  // Rejection sampling: draw from the cube [-1, 1)^3 until the point lies
  // strictly within distance 0.75 of the origin.
  float x, y, z, r;
  do {
    x = drand48()*2.0 - 1.0;
    y = drand48()*2.0 - 1.0;
    z = drand48()*2.0 - 1.0;
    r = sqrt(x*x + y*y + z*z);
  } while (!(r < 0.75));
  SbVec3f target = SbVec3f(x, y, z);

  // Direction start -> target, normalised, then scaled by 1/10.
  SbVec3f heading = start - target;
  heading.normalize();
  heading.negate();
  heading /= 10.0;

  SoParticle *particle = new SoParticle();
  particle->setLocation(start);
  particle->setRadius(deflectorRadius);
  particle->mass     = deflectorMass;
  particle->velocity = heading;
  return particle;
}

// Resolve collisions between `source` and the children of `particles` whose
// indices are listed in `collisions`.
//
// Each contact is treated as a collision with coefficient of restitution
// e = 1 along the line joining the two centres.  Each colliding particle's
// velocity is updated in place.  The source's new velocity is the average of
// the post-collision velocities it would get from each resolved contact,
// all computed from its pre-collision velocity.
//
// A contact is resolved only when the two bodies are approaching along the
// contact normal.  update() detects contacts by overlap and positions are
// never corrected, so a pair that has just bounced can still overlap on the
// next step.  Re-applying the impulse to such a separating pair would reverse
// it back together and leave the particles stuck.  Pairs with coincident
// centres have no defined normal and are skipped too.
//
// If no listed contact is resolved, source->velocity is left unchanged.
void Scatter::updateVelocities (SoParticle *source, SoSeparator *particles,
                                SoMFInt32 &collisions) {
  const float e = 1.0;          // coefficient of restitution (1 = elastic)

  SbVec3f v1i  = source->velocity;
  float   m1   = source->mass;
  SbVec3f spos = source->getLocation();

  SbVec3f v1f(0.0, 0.0, 0.0);  // sum of source's post-collision velocities
  int resolved = 0;             // number of contacts that produced an impulse

  int num = collisions.getNum();
  for (int i = 0; i < num; i++) {
    SoParticle *part = (SoParticle *)particles->getChild(collisions[i]);

    SbVec3f v2i = part->velocity;
    float   m2  = part->mass;

    // Unit normal pointing from the source toward the other particle.
    SbVec3f norm = part->getLocation() - spos;
    if (norm.length() == 0.0)
      continue;                 // coincident centres: no usable normal
    norm.normalize();

    // Closing speed along the normal; <= 0 means already separating.
    float approach = norm.dot(v1i - v2i);
    if (approach <= 0.0)
      continue;

    part->velocity = v2i + ((m1 * approach) / (m1 + m2)) * (1.0 + e) * norm;
    v1f += v1i - ((m2 * approach) / (m1 + m2)) * (1.0 + e) * norm;
    resolved++;
  }

  if (resolved > 0)
    source->velocity = v1f / (float)resolved;
}

// Advance the current particle set by one time step.
//
// Body index -1 is the deflector.  Indices 0..num-1 are the children of the
// current set's particle separator (child 2 of currPSet).  For each body:
// every later-indexed particle whose sphere touches or overlaps it is
// collected and passed to updateVelocities(), so each pair is tested once per
// step.  The body is then moved with moveParticle().  A body farther than
// catchRadius from the origin is not moved; its velocity is zeroed instead.
void Scatter::update () {
  SoSeparator *particles = (SoSeparator *)currPSet->getChild(2);
  int num = particles->getNumChildren();

  // Indices (into `particles`) of the bodies touching the current body.
  // The list is reused across iterations to avoid reallocating it.
  SoMFInt32 collisions;

  for (int i = -1; i < num; i++) {
    SoParticle *part1;
    if (i < 0)
      part1 = deflector;
    else
      part1 = (SoParticle *)(particles->getChild(i));

    collisions.deleteValues(0, -1);

    int count = 0;
    for (int j = i + 1; j < num; j++) {
      SoParticle *part2 = (SoParticle *)(particles->getChild(j));
      SbVec3f diff      = part1->getLocation() - part2->getLocation();
      float distance    = diff.length();

      float radiisum = part1->getRadius();
      radiisum      += part2->getRadius();

      if (distance <= radiisum) {
        collisions.set1Value(count, j);
        count++;
      }
    }

    updateVelocities(part1, particles, collisions);

    if (part1->getLocation().length() > catchRadius)
      part1->velocity = SbVec3f(0.0, 0.0, 0.0);
    else
      part1->moveParticle();
  }
}
