%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%
%% df is a tool for plotting a direction field and solutions of a
%% first order differential equation
%%
%%     y' = f(x,y)
%%
%%    df( ode, x0, x1, y0, y1, vals)
%%
%%    ode: a function handle describing the right-hand side f(x,y)
%%  x0,x1: initial and final times for the plot window
%%  y0,y1: minimum and maximum values of the solution for the plot window
%%   vals: A (n x 2) matrix of initial data for solutions to be plotted, of the form
%%         [ X0, Y0; X1, Y1; X2, Y2; ... ]   (one initial point per row)
%%
%%  Example 1:
%%  f(x,y) = y
%%
%%  rhs = @(x,y) y;
%%  df( rhs, -2, 2, -2, 2, [ 0,0; 0,1; 0, -0.5] );
%%
%%  Example 2:
%%  f(x,y) = y^2
%%
%%  rhs = @(x,y) y.*y; % NOTE: vectorized multiplication!
%%  df( rhs, -2, 2, -2, 2, [ 0,0; 0,1; 0, -0.5] );
%%
%%  Example 3 (logistic equation):
%%  f(x,y) = r*y*(1-y/K);
%%
%%  logistic = @(x,y,r,K) r*y.*(1-y/K); % NOTE: vectorized multiplication .*
%%  initialData = [ 0,0;0,0.5;0,1;0,1.5;0,2;0,2.5];
%%
%%  r = 0.3; K=2;
%%  df( @(x,y) logistic(x,y,r,K), -4,4,0,3, initialData );
%%
%%  Now the same plot with different parameters
%%  r = 0.1; K=2;
%%  df( @(x,y) logistic(x,y,r,K), -4,4,0,3, initialData );
%%
%%  David Maxwell
%%  January 28, 2010

function df( ode, t0, t1, u0, u1, vals)
  % DF  Draw the direction field of u' = ode(t,u) on the window
  % [t0,t1] x [u0,u1] and, optionally, the solution curves through the
  % initial points listed in VALS.
  %
  %   ode   : function handle f(t,u); it is evaluated on scalars via arrayfun
  %   t0,t1 : horizontal extent of the plot window
  %   u0,u1 : vertical extent of the plot window
  %   vals  : optional (n x 2) matrix; row k is an initial point [t_k, u_k].
  %           Each solution is traced forward and backward in time until it
  %           reaches the window edge in t or leaves [u0,u1] in u.
  %
  % If hold was on for the current axes when df was called, it is left on
  % afterwards; otherwise hold is turned off again at the end.

  % Number of grid points for the direction field
  Nt = 21; Nu = 21;

  % Number of steps used to cross the full window width for solution curves
  NT = 100;

  % Create the grid for the direction field
  t = linspace(t0,t1,Nt);
  u = linspace(u0,u1,Nu);
  [tt,uu] = meshgrid(t,u);

  % The vector (1, f(t,u)) at each grid point
  dt = ones(size(tt));
  du = arrayfun(ode,tt,uu);

  % Normalize the vectors so they mostly fill up their cells
  DT = abs(t1-t0)/Nt; DU = abs(u1-u0)/Nu;
  lam = 3*max(abs(dt)/DT, abs(du)/DU);
  dt = dt./lam;
  du = du./lam;

  % A scale factor of 0 disables quiver's automatic rescaling
  S = 0;

  % Remember the caller's hold state so it can be restored at the end
  wasHeld = ishold;

  % quiver draws each vector with its base at the grid point. To get line
  % segments centred on the grid points, draw two arrowless vectors per
  % point, pointing in opposite directions.
  h = quiver(tt,uu,dt,du,S);
  set(h,'showarrowhead','off');

  hold on % (keep drawing on the same figure)
  h = quiver(tt,uu,-dt,-du,S);
  set(h,'showarrowhead','off');

  % For each initial point, plot the solution curve through it
  if( nargin >= 6 )
    % Integration step; independent of the initial point
    step = (t1-t0)/NT;

    % With s = -t and v(s) = u(-s), v'(s) = -f(-s,v). Integrating this
    % forward in s traces the original solution backward in t.
    bwode = @(s,v) -ode(-s,v);

    for k = 1:size(vals,1)
      tv = vals(k,1);
      uv = vals(k,2);

      % Solution curve forward in time
      [U,T] = odewalk( ode, tv, uv, step, t0, t1, u0, u1 );
      h = plot(T,U,'color','black');
      set(h,'linewidth',1.5);

      % Solution curve backward in time; the window [t0,t1] in t
      % corresponds to [-t1,-t0] in s, and T is mapped back with -T
      [U,T] = odewalk( bwode, -tv, uv, step, -t1, -t0, u0, u1 );
      h = plot(-T,U,'color','black');
      set(h,'linewidth',1.5);
    end
  end

  % Constrain the plot window
  axis([t0,t1,u0,u1]);

  % Restore the caller's hold state instead of forcing it off
  if( ~wasHeld )
    hold off
  end
end

function [U,T,istate,msg]=odewalk(ode, tv, uv, dt, t0, t1, u0, u1)
  % ODEWALK  Trace the solution of u' = ode(t,u) from (tv,uv) in steps of dt.
  % Dispatches to odewalkOctave (lsode) under GNU Octave and to
  % odewalkMatlab (ode45) otherwise. All arguments and outputs are passed
  % through unchanged.
  %
  % U, T   : sampled solution values and times, starting with (tv,uv)
  % istate : lsode-style status of the walk
  % msg    : accompanying status message

  % OCTAVE_VERSION is a builtin function only in Octave. The 'builtin'
  % qualifier keeps a variable or file of that name from fooling the check.
  isOctave = exist('OCTAVE_VERSION','builtin') ~= 0;

  % Call the appropriate implementation
  if( isOctave )
    [U,T,istate,msg] = odewalkOctave(ode, tv, uv, dt, t0, t1, u0, u1);
  else
    [U,T,istate,msg] = odewalkMatlab(ode, tv, uv, dt, t0, t1, u0, u1);
  end
end

function [U,T,istate,msg]=odewalkOctave(ode, tv, uv, dt, t0, t1, u0, u1)
  % ODEWALKOCTAVE  Step u' = ode(t,u) forward from (tv,uv) with lsode, one
  % interval of length dt at a time. It stops when t is no longer below t1,
  % when the solution leaves [u0,u1], or when lsode reports failure. The
  % first out-of-window point is still recorded. t0 is accepted for a
  % uniform interface with odewalkMatlab but is not used.
  %
  % U, T   : row vectors of solution values and times, beginning with (tv,uv)
  % istate : lsode status of the last call (2 = success); 0 if no step ran
  % msg    : lsode message of the last call, or the initial placeholder text
  %          if no step ran

  U = uv;
  T = tv;
  istate = 0;
  msg = 'initial value out of bounds';

  % lsode expects f(u,t), with state first and time second
  odeOctave = @(u,t) ode(t,u);

  % Step the solution forward until we leave the view window
  % or encounter an error from lsode
  while( tv < t1 )
    [UU,istate,msg] = lsode(odeOctave, uv, [tv, tv+dt]);
    if( istate ~= 2 )
      % lsode failed; stop and leave its status in istate/msg
      break
    end

    % UU(2) is the solution at the end of this interval
    tv = tv + dt;
    uv = UU(2);
    U(end+1) = uv;
    T(end+1) = tv;

    if( uv < u0 || uv > u1 )
      break
    end
  end
end

function [U,T,istate,msg]=odewalkMatlab(ode, tv, uv, dt, t0, t1, u0, u1)
  % ODEWALKMATLAB  Step u' = ode(t,u) forward from (tv,uv) with ode45, one
  % interval of length dt at a time. It stops when t is no longer below t1
  % or when the solution leaves [u0,u1]. The first out-of-window point is
  % still recorded. t0 is accepted for a uniform interface with
  % odewalkOctave but is not used.
  %
  % U, T   : row vectors of solution values and times, beginning with (tv,uv)
  % istate : always 2, mimicking lsode's success code (ode45 errors are not
  %          caught here and propagate to the caller)
  % msg    : always ''

  U = uv;
  T = tv;
  % Emulate lsode return codes and pretend like we always succeed
  istate = 2;
  msg = '';

  % Step the solution forward until we leave the view window.
  % ode45 already uses the f(t,u) argument order, so ode is passed as-is.
  while( tv < t1 )
    [TT,UU] = ode45(ode, [tv, tv+dt], uv);
    U(end+1) = UU(end);
    T(end+1) = TT(end);
    if( UU(end) < u0 || UU(end) > u1 )
      break
    end
    tv = tv + dt;
    uv = UU(end);
  end
end
