//----------------------------------------------------------------------------
//	Copyright (C) 2002, 2003, 2004 Humboldt-Universitaet zu Berlin
//
//	This library is free software; you can redistribute it and/or
//	modify it under the terms of the GNU Lesser General Public
//	License as published by the Free Software Foundation; either
//	version 2.1 of the License, or (at your option) any later version.
//
//	This library is distributed in the hope that it will be useful,
//	but WITHOUT ANY WARRANTY; without even the implied warranty of
//	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
//	Lesser General Public License for more details.
//
//	You should have received a copy of the GNU Lesser General Public
//	License along with this library; if not, write to the Free Software
//	Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
//
//----------------------------------------------------------------------------
/**	\file Continuous.cpp

	\author Ralf Gerstenberger
	<!-- [\author <author>]* -->

	\date created at 2003/06/15

	\brief Implementation of classes in Continuous.h

	\sa Continuous.h

	<!-- [detailed description] -->

	<!-- [\todo {todos for this file}]* -->

	\since 1.0
*/

#include <odemx/base/Continuous.h>
#include <odemx/util/ErrorHandling.h>

#include <cmath>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>

using namespace odemx;

// Creates a continuous process with a state vector of dimension d.
// All per-component vectors (state, rate, error estimate, saved state and
// the three slope buffers used by takeAStep()) are sized to d up front.
Continuous::Continuous(Simulation* s, Label l, int d, ContinuousObserver* o) :
	Process(s, l, o),
	Observable<ContinuousObserver>(o),
	state(d), rate(d),
	error_vector(d), initial_state(d),
	slope_1(d), slope_2(d), slope_3(d)
{
	// size of the ODE system
	dimension=d;

	// default step length boundaries; the initial step equals the minimum
	minStepLength=0.01;
	stepLength=minStepLength;
	maxStepLength=0.1;

	// default error control: absolute error, tolerance 0.1
	errorLimit=0.1;
	relative=0;

	// no stop requested yet, no time limit, no state condition
	stopped=0;
	stopCond=0;
	stopTime=HUGE_VAL;

	// announce creation: trace first, then observers
	getTrace()->mark(this, markCreate);
	_obsForEach(ContinuousObserver, Create(this));
}

// Destructor: reports the destruction to the trace, then to every
// registered ContinuousObserver.
Continuous::~Continuous() {
	getTrace()->mark(this, markDestroy);
	_obsForEach(ContinuousObserver, Destroy(this));
}

// Interrupts the process. Setting stopped to 2 makes the integration loop in
// integrate() terminate and return 2; the actual interruption is delegated
// to Process::interrupt().
void Continuous::interrupt() {
	stopped=2;
	Process::interrupt();
}

// Sets the boundaries for the adaptive step length.
// Valid input requires 0 < min < max. Invalid input is reported via error()
// and then repaired so that integration can still proceed:
//  - negative values are replaced by their absolute value,
//  - zero values are replaced by the defaults (min 0.01, max 0.1),
//  - if min is still not smaller than max, max is enlarged until it is.
void Continuous::setStepLength(double min, double max) {
	if ( (min<=0)||(max<=0)||(min>=max) ) {
		// Error: invalid step length boundaries
		error("setStepLength(): invalid step length boundaries;");

		if (min<0) min=-min;
		if (max<0) max=-max;
		if (min==0) min=0.01;
		if (max==0) max=0.1;
		if (min>=max) {
			max*=10.0;
			// One decade is not enough when min was much larger than max
			// (e.g. min=0.5, max=0.01 -> max=0.1); make max strictly larger.
			if (min>=max) max=min*10.0;
		}
	}

	minStepLength=min;
	maxStepLength=max;
}

// Registers a peer process. integrate() shortens its steps so that they never
// end after the execution time of a RUNNABLE peer. A null pointer is ignored.
void Continuous::addPeer(Process* d) {
	if (d!=0) {
		peers.push_back(d);

		// announce: trace first, then observers
		getTrace()->mark(this, markAddPeer, d);
		_obsForEach(ContinuousObserver, AddPeer(this, d));
	}
}

// Unregisters a peer process. std::list::remove drops every occurrence of d.
// A null pointer is ignored.
void Continuous::delPeer(Process* d) {
	if (d!=0) {
		peers.remove(d);

		// announce: trace first, then observers
		getTrace()->mark(this, markRemovePeer, d);
		_obsForEach(ContinuousObserver, RemovePeer(this, d));
	}
}

// Integrates the system with adaptive step length until one of these happens:
//  - the simulation time reaches timeEvent (0 means "no time limit"),
//  - the state condition stateEvent (may be 0 = none) becomes true,
//  - interrupt() is called.
// Returns the final value of `stopped`:
//  0 - time limit reached, or the condition was already true before a step,
//  1 - a state event was located by binSearch(),
//  2 - the process was interrupted.
// If the process is not CURRENT an error is reported (integration still runs).
int Continuous::integrate(SimTime timeEvent, Condition stateEvent) {
	if (getProcessState()!=Process::CURRENT) {
		// error: integrate called by another Process
		error("integrate():Continuous is not current");
	}

	// trace
	getTrace()->mark(this, markBeginIntegrate);

	// observer
	_obsForEach(ContinuousObserver, BeginIntegrate(this));

	// initialise stop condition (timeEvent 0 -> no time limit)
	stopTime=timeEvent;
	if (stopTime==0) stopTime=HUGE_VAL;
	stopCond=stateEvent;

	stopped=0;
	stepLength=maxStepLength;
	time=getCurrentTime();

	// reduce priority (restored below); presumably so that processes due at
	// the same time are scheduled ahead of the next integration step
	setPriority( getPriority()-1 );

	while (!stopIntegrate() &&
		   getCurrentTime()<stopTime &&
		   !stopped)
	{
		// save initial state
		initial_state=state;

		// Keep the adaptive step length within [minStepLength, maxStepLength].
		// The lower bound matters after a step was shortened to zero for
		// synchronisation (a peer due now): doubling zero stays zero, so
		// without it the loop would never advance time again.
		if (stepLength>maxStepLength) stepLength=maxStepLength;
		if (stepLength<minStepLength) stepLength=minStepLength;

		// synchronise with timeEvent (stepLength is recomputed below)
		SimTime stepTime=getCurrentTime()+stepLength;
		if (stepTime>stopTime) stepTime=stopTime;

		// synchronise with peers: end the step no later than the next
		// execution time of any RUNNABLE peer
		std::list<Process*>::iterator i;
		for (i=peers.begin(); i!=peers.end(); ++i)
		{
			Process* p=*i;
			if (p->getProcessState()==Process::RUNNABLE)
			{
				SimTime et=p->getExecutionTime();
				if (stepTime>et) stepTime=et;
			}
			assert(stepTime>=getCurrentTime());
		}
		// actual step length after both synchronisations
		stepLength=stepTime-getCurrentTime();

		// integrate
		takeAStep(stepLength);

		// handle integration errors
		while (reduce())
		{
			// while integration error too great:

			// reduce stepLength
			stepLength/=2.0;

			// reset state
			state=initial_state;

			// try again
			takeAStep(stepLength);
		} // end while reduce()

		// handle state events
		if (stopIntegrate())
		{
			// state event -> Binary search for event time
			binSearch();
		}
		else
		{
			// waiting time
			time+=stepLength;
			holdFor(stepLength);

			// trace
			getTrace()->mark(this, markNewValidState);

			// observer
			_obsForEach(ContinuousObserver, NewValidState(this));

			// try to increase stepLength
			stepLength*=2.0;
		}
	} // end main while loop

	// reset priority
	setPriority( getPriority()+1 );

	// trace
	getTrace()->mark(this, markEndIntegrate);

	// observer
	_obsForEach(ContinuousObserver, EndIntegrate(this, stopped));

	return stopped;
}

// Locates a state event by binary search. Called by integrate() after a step
// of length stepLength (starting from initial_state) ended with
// stopIntegrate() true.
// If half the step is already below minStepLength, the full step is accepted.
// Otherwise the step is retried with ever halved sub-steps; each sub-step that
// does not trigger the condition is accepted and becomes the new
// initial_state. Ends with stopped=1, unless interrupt() has already set
// stopped (e.g. while this process was waiting in holdFor()) - an interrupt
// is never overwritten.
void Continuous::binSearch() {
	// Binary search (BS)
	double sl=stepLength/2.0; // step length for BS

	if (sl<minStepLength)
	{
		// step length already too small
		// -> accept results

		stopped=1;

		// waiting time
		time+=stepLength;
		holdFor(stepLength);

		// trace
		getTrace()->mark(this, markNewValidState);

		// observer
		_obsForEach(ContinuousObserver, NewValidState(this));
	}
	else {
		state=initial_state; // reset state

		// `!stopped`: leave as soon as interrupt() set stopped=2 during a
		// holdFor(); otherwise the interruption would be reported as a state
		// event. state always equals the last accepted state here.
		while (sl>=minStepLength && !stopped) {
			// integrate
			takeAStep(sl);

			if (stopIntegrate())
			{
				// step length still too long
				// reduce sl and try again
				sl/=2.0;

				// reset state
				state=initial_state;
			}
			else
			{
				// accept step
				time+=sl;
				holdFor(sl);

				// trace
				getTrace()->mark(this, markNewValidState);

				// observer
				_obsForEach(ContinuousObserver, NewValidState(this));

				initial_state=state;

				// and try to get closer to the state event
				sl/=2.0;
			}
		}

		// report the state event, but keep an interrupt code if one was set
		if (!stopped) stopped=1;
	}
}

// Maximum norm of error_vector (whose entries takeAStep() fills with
// non-negative fabs() values). Returns 0.0 for a zero-dimensional system.
// A NaN component is returned as-is, so reduce() cannot mistake a diverged
// step for an accurate one (a plain '>' comparison would skip the NaN).
double Continuous::errorNorm() {
	double keep_value=0.0;
	// '<' instead of '<= getDimension()-1': the latter wraps around to
	// UINT_MAX for dimension 0 and reads out of bounds
	for (unsigned int i=0; i<getDimension(); ++i) {
		const double e=error_vector[i];
		if (e!=e) return e; // NaN
		if (e>keep_value) keep_value=e;
	}
	return keep_value;
}

// Step length control after takeAStep(). Stores the error norm in
// error_value and returns 1 if the step has to be repeated with half the
// step length, 0 if it is accepted. An inaccurate step is still accepted
// (after reporting an error) when halving would go below minStepLength.
int Continuous::reduce() {
	error_value=errorNorm();

	// accurate enough -> accept
	if (error_value<=errorLimit)
		return 0;

	// too inaccurate but halving is still allowed -> retry
	if (!(stepLength<(2.0*minStepLength)))
		return 1;

	// Error: error too high, but can't reduce step length further
	error("reduce(): can't reduce step length");
	return 0;
}

// Performs one Runge-Kutta-Merson step of length h, starting from
// initial_state at the integration time `time` (which is not advanced here).
// On return `state` holds the new solution
//   y_b = y0 + h/6*(k1 + 4*k4 + k5)
// and error_vector the per-component error estimate 0.2*|y_a - y_b|, where
//   y_a = y0 + h/2*(k1 - 3*k3 + 4*k4)
// is the embedded lower-order solution. If `relative` is set the estimate is
// taken relative to y_a; a component with y_a == 0 falls back to the absolute
// estimate (a relative error is undefined there).
// NOTE(review): derivatives(t) is presumably expected to fill `rate` from the
// current `state` at time t - confirm against Continuous.h.
void Continuous::takeAStep(double h) {
	const int n=getDimension();
	const double hdiv2=h/2.0;
	const double hdiv3=h/3.0;
	const double hdiv6=h/6.0;
	const double hdiv8=h/8.0;
	int i;

	// k1 = f(t, y0); next evaluation point y0 + h/3*k1
	derivatives(time);
	for (i=0; i<n; ++i) {
		slope_1[i]=rate[i];
		state[i]=initial_state[i] + hdiv3*slope_1[i];
	}

	// k2 = f(t+h/3, .); next evaluation point y0 + h/6*(k1+k2)
	derivatives(time+hdiv3);
	for (i=0; i<n; ++i) {
		slope_2[i]=rate[i];
		state[i]=initial_state[i] + hdiv6*(slope_1[i] + slope_2[i]);
	}

	// k3 = f(t+h/3, .), stored in slope_2 (k2 is no longer needed);
	// next evaluation point y0 + h/8*(k1+3*k3)
	derivatives(time+hdiv3);
	for (i=0; i<n; ++i) {
		slope_2[i]=rate[i];
		state[i]=initial_state[i] + hdiv8*(slope_1[i] + 3.0*slope_2[i]);
	}

	// k4 = f(t+h/2, .); y_a is both the next evaluation point and the
	// lower-order solution, kept in error_vector for the error estimate
	derivatives(time+hdiv2);
	for (i=0; i<n; ++i) {
		slope_3[i]=rate[i];
		state[i]=initial_state[i] + hdiv2*(slope_1[i] - 3.0*slope_2[i] + 4.0*slope_3[i]);
		error_vector[i]=state[i];
	}

	// k5 = f(t+h, y_a), stored in slope_2; y_b becomes the new state
	derivatives(time+h);
	for (i=0; i<n; ++i) {
		slope_2[i]=rate[i];
		state[i]=initial_state[i] + hdiv6*(slope_1[i] + 4.0*slope_3[i] + slope_2[i]);

		if (relative && error_vector[i]!=0.0)
			error_vector[i]=std::fabs(0.2*(1.0-state[i]/error_vector[i]));
		else
			// absolute estimate; also used in relative mode when y_a is 0,
			// which would otherwise report zero error whatever y_b is
			error_vector[i]=std::fabs(0.2*(error_vector[i]-state[i]));
	}
}

// Evaluates the state-event condition set by integrate().
// Without a condition (stopCond == 0) integration is never stopped here.
bool Continuous::stopIntegrate() {
	if (stopCond==0)
		return false;

	return (this->*stopCond)();
}

// Trace mark ids of Continuous are allocated as offsets from baseMarkId.
const MarkTypeId Continuous::baseMarkId=1000;

// marks emitted by the constructor / destructor
const MarkType Continuous::markCreate("create", baseMarkId+1, typeid(Continuous));
const MarkType Continuous::markDestroy("destroy", baseMarkId+2, typeid(Continuous));

// marks emitted by addPeer() / delPeer()
const MarkType Continuous::markAddPeer("addPeer", baseMarkId+3, typeid(Continuous));
const MarkType Continuous::markRemovePeer("removePeer", baseMarkId+4, typeid(Continuous));

// marks emitted at the start and end of integrate()
const MarkType Continuous::markBeginIntegrate("beginIntegrate", baseMarkId+5, typeid(Continuous));
const MarkType Continuous::markEndIntegrate("endIntegrate", baseMarkId+6, typeid(Continuous));

// mark emitted whenever a step is accepted (integrate() / binSearch())
const MarkType Continuous::markNewValidState("newValidState", baseMarkId+7, typeid(Continuous));

// Base for tag ids of Continuous; not referenced elsewhere in this file.
const TagId Continuous::baseTagId=1000;

// Creates a trace writer for contu (if given) and registers it as observer.
// Output target: fileName if given; otherwise "<label of contu>_trace.txt";
// with neither, the file is opened on the first onNewValidState() call.
ContuTrace::ContuTrace(Continuous* contu, const char* fileName) : out(0), firstTime(true)
{
	using namespace std;

	if (contu!=0) {
		contu->Observable<ContinuousObserver>::addObserver(this);
	}

	if (fileName!=0) {
		openFile(fileName);
	}
	else if (contu!=0) {
		// derive the default file name from the observed process
		string fn=contu->getLabel();
		fn.append("_trace.txt");
		openFile(fn.c_str());
	}
}

// Closes and frees the output file. Nothing is done if no stream was opened
// or if output went to std::cout (which this object does not own).
ContuTrace::~ContuTrace()
{
	using namespace std;

	if (out!=0 && out!=&cout) {
		static_cast<ofstream*>(out)->close();
		delete out;
	}
}

// Builds an indexed column header such as "state[3]".
static std::string indexedLabel(const char* name, unsigned int index)
{
	std::ostringstream label;
	label << name << '[' << index << ']';
	return label.str();
}

// Writes one line per accepted step: time, step length, error norm, then all
// state and all rate components. On the first call the output file is opened
// (if not done yet, named "<label>_trace.txt") and a header is written.
void ContuTrace::onNewValidState(Continuous* sender)
{
	using namespace std;

	assert(sender!=0);
	unsigned int i=0;

	if (firstTime) {
		if (out==0) {
			string fn=sender->getLabel();
			fn += "_trace.txt";
			openFile(fn.c_str());
		}

		*out << endl;
		*out << setw(60) << "T R A C E   F I L E   F O R  " << sender->getLabel() << endl << endl;
		*out << setw(13) << "Time";
		*out << setw(13) << "Steplength";
		*out << setw(13) << "Error";

		// Pad the complete label to the data cell width (17). Padding only the
		// "state[" prefix made header cells wider than 17 for indices >= 10.
		for (i=0; i<sender->getDimension(); ++i)
			*out << setw(17) << indexedLabel("state", i);

		for (i=0; i<sender->getDimension(); ++i)
			*out << setw(17) << indexedLabel("rate", i);

		*out << endl;

		firstTime=false;
	}

	// fixed notation with 7 decimals; the floatfield mask clears any
	// previously set scientific flag instead of combining with it
	out->setf(std::ios::showpoint);
	out->setf(std::ios::fixed, std::ios::floatfield);
	*out << setw(13) << setprecision(7) << sender->getCurrentTime();
	*out << setw(13) << setprecision(7) << sender->getStepLength();
	*out << setw(13) << setprecision(7) << sender->error_value;

	for (i=0; i<sender->getDimension(); i++)
		*out << setw(17) << setprecision(7) << sender->state[i];

	for (i=0; i<sender->getDimension(); i++)
		*out << setw(17) << setprecision(7) << sender->rate[i];

	*out << endl;
}

// Opens fileName as the trace output. If the file cannot be opened (or
// fileName is null), a message goes to std::cerr and output is redirected to
// std::cout. The stream that failed to open is released rather than leaked.
void ContuTrace::openFile(const char* fileName)
{
	using namespace std;

	// note: new throws on allocation failure, so only the stream state
	// needs to be checked
	ofstream* file = (fileName!=0) ? new ofstream(fileName) : 0;
	if (file!=0 && !file->fail()) {
		out = file;
		return;
	}

	// Error: cannot open file
	delete file;
	cerr << "ContuTrace cannot open file";
	if (fileName!=0) cerr << " " << fileName;
	cerr << "; sending output to stdout." << endl;
	out = &cout;
}

