#include "RayTracerStudent.h"
#include "Transformer3DStudent.h"

// Composite scene object: owns a flat list of child objects and forwards
// intersection and shadow queries to each of them in insertion order.
class RayTracerStudent::uber_object: public object
{
public:

        // Children owned by this composite; every pointer is deleted by the destructor.
        list<object*> objects;

        // Composites pass surface id 0 to the object base.
        uber_object(): object(0) {}//ends constructor

        // Deletes every child that was handed over through insert().
        ~uber_object()
        {
                while (!objects.empty())
                {
                        delete objects.back();
                        objects.pop_back();
                }//ends while loop
        }//ends destructor

        // Always true for a composite.
        // NOTE(review): presumably lets callers tell composites from leaf
        // objects -- confirm against the declaration of object::test().
        bool test() const
        {
                return true;
        }//ends function test

        // Appends temp_object to the child list and takes ownership of it.
        // The previous implementation compared an iterator initialised to
        // objects.end() against objects.end(), so it always appended; the
        // nesting branches behind that test were unreachable (and would have
        // dereferenced end()). Null pointers are ignored because intersect()
        // and check_shadow() dereference every stored child.
        void insert(object* temp_object)
        {
                if (!temp_object)
                {
                        return;
                }//ends if statement
                objects.push_back(temp_object);
        }//ends function insert

        // Offers the ray to every child and returns true if any of them
        // reported a hit. The out-parameters are shared between the calls, so
        // whatever the last accepting child wrote is what the caller sees (the
        // sphere and triangle implementations in this file only accept hits
        // closer than the incoming u and then lower u). No short-circuiting:
        // every child must be tried.
        bool intersect(const Ray& ray, float& u, Point3f& p, Point3f& n, Material& mat, int& sID, bool& in) const
        {
                bool any_hit = false;
                for (list<object*>::const_iterator child = objects.begin(); child != objects.end(); ++child)
                {
                        if ((*child)->intersect(ray, u, p, n, mat, sID, in))
                        {
                                any_hit = true;
                        }//ends if statement
                }//ends for loop
                return any_hit;
        }// ends method intersect

        // Returns true as soon as one child reports that it blocks the ray
        // before parameter u; false when no child does.
        bool check_shadow(const Ray& ray, float u) const
        {
                for (list<object*>::const_iterator child = objects.begin(); child != objects.end(); ++child)
                {
                        if ((*child)->check_shadow(ray, u))
                        {
                                return true;
                        }//ends if statement
                }//ends for loop
                return false;
        }//ends method check_shadow

private:

        // Not copyable: a copy would share the child pointers and both
        // destructors would delete them. Declared but intentionally not defined.
        uber_object(const uber_object&);
        uber_object& operator=(const uber_object&);
};

// Out-of-class definition of the static counter object::sid_tracker; it starts at 1.
// NOTE(review): presumably the source of the ids handed out by object::getNextSID()
// (used in addTriangle), leaving 0 for uber_object -- confirm in the header.
int object::sid_tracker = 1;

// Stores the sphere's centre, radius, material and texture pointer.
// tex may be NULL; intersect() checks for that before sampling it.
// No object base initialiser is given, so the base is default-constructed.
sphere::sphere(const Point3f& c, float r, const Material& m, const Texture* tex)
	: center(c),
	  radius(r),
	  material(m),
	  texture(tex)
{
}//ends constructor

bool sphere::intersect(const Ray& ray, float& u, Point3f& p, Point3f& n, Material& mat, int& sID, bool& in) const
{
	bool interior;
	Point2i size;
	Point2f texture_point;
	Point3f object_center;
	float object_center_abs, radius_squared, ray_object_dot, approach, approach_test;

	object_center = center - ray.o;
	object_center_abs = object_center.abs2();
	radius_squared = radius*radius;
	interior = (ray.sID==this->sID)?(ray.in):(object_center_abs <= radius_squared);
	ray_object_dot = object_center.dot(ray.d);

	if(!interior && (ray_object_dot < 0))
	{
		return false;
	}//ends if statement

	approach = radius_squared - object_center_abs + ray_object_dot*ray_object_dot;

	if(approach < 0)
	{
		return false;
	}//ends if statement

	approach_test = interior?(ray_object_dot + sqrt(approach)):(ray_object_dot - sqrt(approach));

	if(approach_test >= u)
	{
		return false;
	}//ends if statement
	
	sID = this->sID;
        in = interior;
	u = approach_test;
	p = ray(u);
	n = (p-center).normalized();
	mat = material;

	if(interior)
	{
		mat.i = 1/material.i;
		mat.t = ColorRGBf::white;
		n = -n;
		mat.d = ColorRGBf::black;
		mat.s = ColorRGBf::black;
	}//ends if statement
	else
	{
		if(texture != NULL)
		{
			size = texture->bounds().size();
			Point2f texture_point;
			texture_point.x = ((atan2(n.x, n.z)/(2*pi))+0.5)*size.x;
			texture_point.y = ((atan2(n.y, sqrt(n.x*n.x + n.z*n.z))/pi)+0.5)*size.y;
			mat.d = (*texture)[texture_point].inFrontOf(mat.d);
		}//ends if statement
	}//ends else statement

	return true;
}//ends function intersect
	
bool sphere::check_shadow(const Ray& ray, float u) const
{
	float object_center_abs, radius_squared, ray_object_dot, approach, temp_float;
	Point3f object_center;
	
	object_center = center - ray.o;
	object_center_abs = object_center.abs2();
	radius_squared = radius*radius;
	if ((ray.sID==this->sID)?(ray.in):(object_center_abs <= radius_squared))
	{
		return true;
	}//ends if statement
	ray_object_dot = object_center.dot(ray.d);
	if (ray_object_dot < 0)
	{
		return false;
	}//ends if statement
	approach = radius_squared - object_center_abs + ray_object_dot*ray_object_dot;
	if (approach < 0)
	{
		return false;
	}//ends if statement
	temp_float = ray_object_dot - sqrt(approach);
	if ((temp_float >= u)||(temp_float <= 0))
	{
		return false;
	}//ends if statement
	return true;
}//ends method check_shadow

// Builds a triangle from three vertices and precomputes what the intersection
// routines need:
//   orientation : 1, 2 or 3 for the x, y or z component of the edge cross
//                 product with the largest magnitude (ties go to x, then y);
//   width       : that signed component, used as the divisor for the
//                 barycentric weights in intersect() and check_shadow();
//   plane       : the supporting plane through the three vertex positions.
// tex may be NULL.
triangle::triangle(int sID, const Vertex& v0, const Vertex& v1, const Vertex& v2, const Material& m, const Texture* tex): object(sID), material(m), texture(tex)
{
	vertex_array[0] = v0;
	vertex_array[1] = v1;
	vertex_array[2] = v2;

	Point3f cross_product;
	cross_product = (v1.p - v0.p).cross(v2.p - v0.p);

	// fabs instead of unqualified abs: depending on which headers are visible,
	// abs(float) can bind to the integer abs(int) and truncate the components,
	// which would pick the wrong dominant axis for small triangles.
	const float magnitude_x = fabs(cross_product.x);
	const float magnitude_y = fabs(cross_product.y);
	const float magnitude_z = fabs(cross_product.z);

	if((magnitude_x >= magnitude_y) && (magnitude_x >= magnitude_z))
	{
		width = cross_product.x;
		orientation = 1;
	}//ends if statement
	else if((magnitude_y >= magnitude_x) && (magnitude_y >= magnitude_z))
	{
		width = cross_product.y;
		orientation = 2;
	}//ends else if statement
	else
	{
		width = cross_product.z;
		orientation = 3;
	}//ends else statement

	plane = Plane3f(v0.p, v1.p, v2.p);
}//ends method triangle

bool triangle::intersect(const Ray& ray, float& u, Point3f& p, Point3f& n, Material& mat, int& sID, bool& in) const
{
	bool inside;
	float plane_square, plane_ray_dist, temp_float;
	Point3f calc_point, temp_point, new_norm_point;
	
	plane_square = plane.a*ray.d.x + plane.b*ray.d.y + plane.c*ray.d.z;
	plane_ray_dist = plane.dist(ray.o);
	temp_float = -(plane_ray_dist/plane_square);
	inside = (plane_ray_dist <= 0);
	temp_float = -(plane_ray_dist/plane_square);

	if(temp_float >= u || temp_float <= 0)
	{
		return false;
	}//ends if statement

	calc_point = ray(temp_float);

	if(orientation == 1)
	{
		temp_point.x = ((vertex_array[1].p.y - calc_point.y)*(vertex_array[2].p.z - calc_point.z)-(vertex_array[1].p.z - calc_point.z)*(vertex_array[2].p.y - calc_point.y));
		temp_point.y = ((vertex_array[2].p.y - calc_point.y)*(vertex_array[0].p.z - calc_point.z)-(vertex_array[2].p.z - calc_point.z)*(vertex_array[0].p.y - calc_point.y));
		temp_point.x/=width;
		temp_point.y/=width;
		temp_point.z = 1 - temp_point.x - temp_point.y;
	}//ends if statement
	else if(orientation == 2)
	{
		temp_point.x = ((vertex_array[1].p.z - calc_point.z)*(vertex_array[2].p.x - calc_point.x)-(vertex_array[1].p.x - calc_point.x)*(vertex_array[2].p.z - calc_point.z));
		temp_point.y = ((vertex_array[2].p.z - calc_point.z)*(vertex_array[0].p.x - calc_point.x)-(vertex_array[2].p.x - calc_point.x)*(vertex_array[0].p.z - calc_point.z));
		temp_point.x/=width;
		temp_point.y/=width;
		temp_point.z = 1 - temp_point.x - temp_point.y;
	}//ends else if statement
	else
	{
		temp_point.x = ((vertex_array[1].p.x - calc_point.x)*(vertex_array[2].p.y - calc_point.y)-(vertex_array[1].p.y - calc_point.y)*(vertex_array[2].p.x - calc_point.x));
		temp_point.y = ((vertex_array[2].p.x - calc_point.x)*(vertex_array[0].p.y - calc_point.y)-(vertex_array[2].p.y - calc_point.y)*(vertex_array[0].p.x - calc_point.x));
		temp_point.x/=width;
		temp_point.y/=width;
		temp_point.z = 1 - temp_point.x - temp_point.y;
	}//ends else statement
	if((temp_point.x < 0) || (temp_point.y < 0) || (temp_point.z < 0))
	{
		return false;
	}//ends if statement
	new_norm_point = vertex_array[0].n*temp_point.x + vertex_array[1].n*temp_point.y + vertex_array[2].n*temp_point.z;
	if(inside)
	{
		new_norm_point = -new_norm_point;
	}//ends if statement
	if(new_norm_point.dot(ray.d) >= 0)
	{
		return false;
	}//ends if statement

	u = temp_float;
	p = calc_point;
	n = new_norm_point.normalized();
	mat = material;

	if(inside)
	{
		mat.i = 1/material.i;
		mat.d = mat.s = ColorRGBf::black;
		mat.t = ColorRGBf::white;
	}//ends if statement
	else
	{
		mat.d = ColorRGBf(vertex_array[0].c*temp_point.x + vertex_array[1].c*temp_point.y + vertex_array[2].c*temp_point.z);
		if(texture != NULL)
		{
			mat.d = (*texture)[vertex_array[0].t*temp_point.x + vertex_array[1].t*temp_point.y + vertex_array[2].t*temp_point.z].inFrontOf(mat.d);
		}//ends if statement
	}//ends else statement

	sID = this->sID;
	in = inside;
	return true;
}//ends method intersect

bool triangle::check_shadow(const Ray& ray, float u) const
{
	bool inside;
	float plane_square, plane_ray_dist, temp_float;
	Point3f temp_point, calc_point, new_norm_point;
	
	plane_square = plane.a*ray.d.x + plane.b*ray.d.y + plane.c*ray.d.z;
	plane_ray_dist = plane.dist(ray.o);
	inside = (plane_ray_dist <= 0);
	if((ray.sID==this->sID) && (ray.in != inside))
	{
		return false;
	}//ends if statement

	temp_float = -(plane_ray_dist/plane_square);

	if(temp_float >= u || temp_float <= 0)
	{
		return false;
	}//ends if statement

	calc_point = ray(temp_float);

	if(orientation == 1)
	{
		temp_point.x = ((vertex_array[1].p.y - calc_point.y)*(vertex_array[2].p.z - calc_point.z)-(vertex_array[1].p.z - calc_point.z)*(vertex_array[2].p.y - calc_point.y));
		temp_point.y = ((vertex_array[2].p.y - calc_point.y)*(vertex_array[0].p.z - calc_point.z)-(vertex_array[2].p.z - calc_point.z)*(vertex_array[0].p.y - calc_point.y));
		temp_point.x/=width;
		temp_point.y/=width;
		temp_point.z = 1 - temp_point.x - temp_point.y;
	}//ends if statement
	else if (orientation == 2)
	{
		temp_point.x = ((vertex_array[1].p.z - calc_point.z)*(vertex_array[2].p.x - calc_point.x)-(vertex_array[1].p.x - calc_point.x)*(vertex_array[2].p.z - calc_point.z));
		temp_point.y = ((vertex_array[2].p.z - calc_point.z)*(vertex_array[0].p.x - calc_point.x)-(vertex_array[2].p.x - calc_point.x)*(vertex_array[0].p.z - calc_point.z));
		temp_point.x/=width;
		temp_point.y/=width;
		temp_point.z = 1 - temp_point.x - temp_point.y;
	}//ends else if statement
	else
	{
		temp_point.x = ((vertex_array[1].p.x - calc_point.x)*(vertex_array[2].p.y - calc_point.y)-(vertex_array[1].p.y - calc_point.y)*(vertex_array[2].p.x - calc_point.x));
		temp_point.y = ((vertex_array[2].p.x - calc_point.x)*(vertex_array[0].p.y - calc_point.y)-(vertex_array[2].p.y - calc_point.y)*(vertex_array[0].p.x - calc_point.x));
		temp_point.x/=width;
		temp_point.y/=width;
		temp_point.z = 1 - temp_point.x - temp_point.y;
	}//ends else statement

	if((temp_point.x < 0) || (temp_point.y < 0) || (temp_point.z < 0))
	{
		return false;
	}//ends if statement

	new_norm_point = vertex_array[0].n*temp_point.x + vertex_array[1].n*temp_point.y + vertex_array[2].n*temp_point.z;

	if(inside == (new_norm_point.dot(ray.d) < 0))
	{
		return false;
	}//ends if statement

	return true;

}//ends method check_shadow

// Forwards both transformers to the RayTracer base, then allocates an empty
// scene container and the render-progress record. Both are released in the
// destructor.
RayTracerStudent::RayTracerStudent(const Transformer2D* transformer2, const Transformer3D* transformer3)
	: RayTracer(transformer2, transformer3)
{
	objects = new uber_object();
	info = new render_info();
}//ends constructor

// Frees the scene container (whose destructor deletes every object it holds)
// and then the render-progress record.
RayTracerStudent::~RayTracerStudent()
{
	delete objects;

	delete info;
}//ends destructor

void RayTracerStudent::addTriangle(const Vertex& v0, const Vertex& v1, const Vertex& v2, const Material& m, const Texture* tex)
{
	objects->insert(new triangle(object::getNextSID(), v0, v1, v2, m, tex));
}//ends method addTriangle

// Creates a sphere from centre, radius, material and optional texture and
// hands it to the scene container, which takes ownership.
// NOTE(review): the transform t is accepted but never used here -- confirm
// whether c and r are expected to be transformed before the sphere is built.
void RayTracerStudent::addSphere(const Point3f& c, float r, const Material& mat, const Texture* tex, const Transform3f& t)
{
	objects->insert(new sphere(c, r, mat, tex));
}//ends method addSphere

void RayTracerStudent::removeAll()
{
	materials.clear();
	lights.clear();
	delete objects;
	objects = new uber_object();
}//ends method removeAll

// Intersects the ray with the whole scene by delegating to the scene
// container; all out-parameters are passed through unchanged.
// Returns true when something in the scene was hit.
bool RayTracerStudent::intersect(const Ray& ray, float& u, Point3f& p, Point3f& n, Material& mat, int& sID, bool& in) const
{
	const bool hit_something = objects->intersect(ray, u, p, n, mat, sID, in);
	return hit_something;
}//ends method intersect

// Asks the scene container whether anything blocks the ray before parameter u.
bool RayTracerStudent::check_shadow(const Ray& ray, float u) const
{
	const bool blocked = objects->check_shadow(ray, u);
	return blocked;
}//ends method check_shadow

void RayTracerStudent::getVisibleLights(const Point3f& p, const Point3f& n, int sID, bool in, list<Light*>& ls) const
{
	float distance;	

	ls.clear();

	list<Light*>::const_iterator light_iterator = lights.begin();
	while(light_iterator != lights.end())
	{
		Ray ray(p, Point3f(), sID, in);
		(*light_iterator)->sourceAt(p, ray.d, distance);
		if(ray.d.dot(n) > 0)
		{
			if(!check_shadow(ray, distance))
			{
				ls.push_back(*light_iterator);
			}//ends if statement
		}//ends if statement
		light_iterator++;
	}//ends while loop
}//ends method getVisibleLights

// Local shading at point p with normal n for the lights in ls.
// Returns (ambient + sum of light * n.l) * mat.d
//       + (sum of light * highlight^mat.sp) * mat.s,
// where highlight is the negated dot product of the viewing ray direction
// with the light direction mirrored about n. Lights with n.l <= 0 contribute
// nothing; the specular term is added only when highlight > 0 and mat.sp >= 1.
ColorRGBf RayTracerStudent::getIllumination(const Ray& ray, const Point3f& p, const Point3f& n, const Material& mat, const list<Light*>& ls) const
{
	ColorRGBf diffuse_sum;
	ColorRGBf specular_sum;
	Point3f l;
	float d;

	// The diffuse accumulator starts from the ambient term.
	diffuse_sum = ambient;
	specular_sum = ColorRGBf::black;

	for(list<Light*>::const_iterator light = ls.begin(); light != ls.end(); ++light)
	{
		// allAt() supplies direction l, distance d and colour li for point p.
		ColorRGBf li;
		(*light)->allAt(p, l, d, li);

		const float n_dot_l = n.dot(l);
		if(!(n_dot_l > 0))
		{
			continue;
		}//ends if statement

		diffuse_sum += li * n_dot_l;

		// n*(2 n.l) - l is l mirrored about the normal.
		const float highlight = -(ray.d.dot(n*(n_dot_l*2)-l));
		if(highlight > 0 && mat.sp >= 1)
		{
			specular_sum += li * pow(highlight, mat.sp);
		}//ends if statement
	}//ends for loop

	return diffuse_sum*mat.d + specular_sum*mat.s;
}//ends method getIllumination

// Mirrors the ray direction about the normal n: d - 2 (n.d) n.
// mat is not used.
Point3f RayTracerStudent::getReflection(const Ray& ray, const Point3f& n, const Material& mat) const
{
	const float twice_projection = 2 * n.dot(ray.d);

	return ray.d - n*twice_projection;
}//ends method getReflection

// Refracts the ray direction through a surface with normal n using the
// ratio 1/mat.i.
//   ir : out - set to 1 when no transmitted direction exists (the squared
//        transmitted cosine is not positive), otherwise 0.
// Returns the transmitted direction, or a default-constructed Point3f when
// ir is set to 1.
Point3f RayTracerStudent::getTransmission(const Ray& ray, const Point3f& n, const Material& mat, float& ir) const
{
	const float eta = 1/mat.i;
	const float cos_incident = -n.dot(ray.d);

	// sin^2 of the incident angle, scaled by eta^2, gives sin^2 of the
	// transmitted angle; one minus that is its cos^2.
	const float sin2_incident = (1-(cos_incident*cos_incident));
	const float sin2_transmitted = sin2_incident * eta * eta;
	const float cos2_transmitted = 1-sin2_transmitted;

	if(cos2_transmitted <= 0)
	{
		ir = 1;
		return Point3f();
	}//ends if statement

	const float cos_transmitted = sqrt(cos2_transmitted);
	ir = 0;

	// eta*d - (cos_t - eta*cos_i) * n
	const float normal_scale = cos_transmitted - cos_incident * eta;
	return ray.d*eta - n*normal_scale;
}//ends method getTransmission

// Prepares a new render: stores the inverse of the view transform and starts
// the progress record with the display size and camera.
// msg is not used. Always returns true.
bool RayTracerStudent::startRender(Display& display, RayCamera* camera, const Transform2f& viewt, char* msg)
{
	// Store the inverse view transform; renderSlice applies it to every pixel
	// coordinate before asking the camera for a ray.
	info->inverted = viewt.inverse();

	info->start(display.bounds().size(), camera);

	return true;
}//ends method startRender

// Renders whole pixel columns, one at a time, until roughly msecs
// milliseconds have elapsed or the image is complete.
//   display : target; each pixel of the current column is written to it.
//   msecs   : time budget for this call, checked between columns.
//   msg     : not used.
// Returns true while the render is still running (more columns remain),
// false once it has finished or was halted.
bool RayTracerStudent::renderSlice(Display& display, unsigned long msecs, char* msg)
{
	Point2i size, pixel;
	Point2f view_point;
	StopWatch timer;

	if(!info->running)
	{
		return false;
	}//ends if statement

	size = info->bool_map.bounds().size();

	while(timer.elapsed() < msecs)
	{
		if(!info->running)
		{
			return false;
		}//ends if statement

		// Nothing left to draw. Checking with >= before rendering also covers
		// a zero-width image and a column index that is already past the end;
		// the old post-increment == test never fired in those cases and kept
		// writing columns outside the image.
		if(info->x_line >= size.x)
		{
			info->halt();
			return false;
		}//ends if statement

		pixel.x = info->x_line;

		for(pixel.y = 0; pixel.y < size.y; pixel.y++)
		{
			// Map the pixel through the inverse view transform, then trace the
			// camera ray for that point.
			view_point = pixel;
			transformer2->apply(info->inverted, view_point);
			display[pixel] = ColorRGBb(getColor(info->camera->rayAt(view_point)));
		}//ends for loop

		info->x_line++;

		// Last column done: stop the render.
		if(info->x_line >= size.x)
		{
			info->halt();
		}//ends if statement
	}//ends while loop

	return info->running;
}//ends method renderSlice

// Stops the render in progress by halting the progress record; renderSlice
// returns false whenever info->running is false.
void RayTracerStudent::abortRender()
{
	info->halt();
}//ends method abortRender
