* q: dynamic rank (number of common shocks to factors)
* r: static rank (r>=q) (number of factors)
* p: ar order of the state vector (default p=1) (for VAR of factors)

* ************************************************************************
* currently when p is not 1 or q is not 1 the program will return an error
* ************************************************************************

* The model
* x_t = C F_t + \xi_t
* F_t = AF_{t-1} + B u_t
* R = E(\xi_t \xi_t')
* Q = BB'
* u_t ~ WN(0,I_q)
* initx = F_0
* initV = E(F_0 F_0')
* ss: std(x) 
* MM: mean(x)
* F : estimated factors
* VF: estimation variance for the common factors

* FactorExtraction varlist [if] [in] [, q(#) r(#) p(#) name(stub) date(varname)]
*
* Ado wrapper around the Mata routine FactorExtraction():
*   - copies varlist (restricted to the if/in sample and, when date() is
*     given, to observations with a nonmissing date) into the Mata matrix x;
*   - runs the Mata factor extraction with the requested q, r and p;
*   - copies the r estimated factors back as variables <stub>1 ... <stub>r
*     (stub defaults to EstFactor), matched to the data through an id variable;
*   - posts the matrices A and C as Stata matrices.
* Side effects: the Mata globals x, myid, F, VF, A, C and TP1..TPr are
* (re)created.
program FactorExtraction
	version 12
	syntax varlist [if] [in] [, q(integer 1) r(integer 2) p(integer 1) name(name) date(varname)]
	* example command
	* FactorExtraction stdx1 stdx2 stdx3 stdx4, q(1) r(2) p(1) name(estf)
	
	if "`name'"=="" {
		local name "EstFactor"
	}

	qui {
		* honour [if] [in]; novarlist keeps observations with missing values
		* in varlist, because the Mata routine handles missing data itself
		marksample touse, novarlist
		local cond "if `touse'"
		* if date specified, put only observations with nonmissing dates
		if "`date'"!="" {
			local cond "`cond' & `date'!=."
		}
		* generate id variable to match estimated factors and original data set
		* (long rather than float so that large observation numbers stay exact)
		tempvar myid
		gen long `myid'=_n
		* proceed with factor extraction in mata
		putmata x=(`varlist') myid=`myid' `cond', replace
		noi mata: FactorExtraction(x, `q', `r', `p')
		mata: F=fe1
		mata: VF=fe2
		mata: A=fe3
		mata: C=fe4
		mata: st_matrix("A", A)
		mata: st_matrix("C", C)
		forvalues varid = 1/`r' {
			* drop the variable about to be written (not always the r-th one),
			* so that no stale values survive in observations outside the sample
			cap drop `name'`varid'
			mata: TP`varid'=F[.,`varid']
			getmata `name'`varid'=TP`varid', id(`myid'=myid) replace
		}	
	}

end

version 12

mata:

	// LogShow(text): progress logger.
	// Prints "[", then the current time (via Stata's display of
	// c(current_time), without a line break), then "]", and finally the
	// message itself.
	function LogShow (text) 
	{
		printf("[")
		stata("display as text c(current_time) _continue")
		printf("]")
		// bare expression statement: Mata displays the value of text
		text
	}

	// FactorExtraction(x, q, r, p [, A, C, Q, R, initx, initV, ss, MM])
	//
	// Standardizes x, obtains the state-space parameters (estimated by
	// ricSW() when fewer than five arguments are supplied, otherwise taken
	// from the arguments) and runs kalman_smoother_diag() on the standardized
	// data, with missing observations given a measurement variance inflated
	// by 1e+32.
	//
	// Arguments
	//   x            T x N data matrix; NOTE: overwritten with its
	//                standardized version (the assignment to x below)
	//   q, r, p      passed on to ricSW()
	//   A,C,Q,R      state-space matrices (optional)
	//   initx,initV  initial state mean and variance (optional)
	//   ss, MM       1 x N scale and location used to standardize x
	//                (optional; computed from the first T-m rows when omitted)
	//
	// Results are stored in the Mata externals fe1..fe10:
	//   fe1 F (smoothed states, T x state dim), fe2 VF (asarray of smoothed
	//   variances), fe3 A, fe4 C, fe5 Q, fe6 R, fe7 initx, fe8 initV,
	//   fe9 ss, fe10 MM.
	function FactorExtraction(|x,q,r,p,A,C,Q,R,initx,initV,ss,MM)
	{
		LogShow("Starting factor extraction")
		LogShow("Setting initial values")
		T = rows(x)
		// m: largest number of missing entries found in any single column;
		// the first T-m rows form the estimation sample
		m = max(colmissing(x))
		ttt = T-m

		// Scale and location are computed from the data only when the caller
		// did not supply them (previously supplied values were overwritten).
		if (args() < 12) {
			if (ttt < 2) {
				_error("FactorExtraction(): fewer than two rows are left once the largest per-column missing count is removed")
			}
			z = x[1..ttt,.]
			if (args() < 11) {
				ss = sqrt(diagonal(variance(z)))'
			}
			MM = mean(z)
		}

		if (args() < 5) {
			LogShow("Calculating model parameters initially")
		}
		else {
			LogShow("Taking given model parameters")
		}

		// standardize every row with the same 1 x N location and scale
		x = (x - J(T,1,1)*MM):/(J(T,1,1)*ss)

		if (args() < 5) {
			z = x[1..ttt,.]

			LogShow("Calling ricSW")
			ricSW(z,q,r,p)
			LogShow("Completed ricSW")
			// ricSW() hands its results back through the externals ricsw1..ricsw6
			ricsw1 = findexternal("ricsw1")
			A = *ricsw1
			ricsw2 = findexternal("ricsw2")
			C = *ricsw2
			ricsw3 = findexternal("ricsw3")
			Q = *ricsw3
			ricsw4 = findexternal("ricsw4")
			R = *ricsw4
			ricsw5 = findexternal("ricsw5")
			initx = *ricsw5
			ricsw6 = findexternal("ricsw6")
			initV = *ricsw6
		}

		LogShow("Preparing Kalman filter")
		// one copy of each system matrix per period, keyed by the period index
		AA = asarray_create("real", 1)
		QQ = asarray_create("real", 1)
		CC = asarray_create("real", 1)
		RR = asarray_create("real", 1)
		Rdiag = diagonal(R)
		for (jt=1; jt<=T; jt++) {
			asarray(AA,jt,A)
			asarray(QQ,jt,Q)
			asarray(CC,jt,C)
			// a missing observation gets a huge measurement variance
			miss = colmissing(x[jt,.])
			asarray(RR,jt,diag(Rdiag+miss':*1e+32))
		}
		// the smoother receives zeros in place of missing values
		xx = x
		_editmissing(xx,0)
		ott = (1..T)
		LogShow("Calling Kalman smoother diag")
		kalman_smoother_diag(xx',AA, CC, QQ, RR, initx, initV, ott)
		LogShow("Complted Kalman smoother diag")
		// smoothed states (ksd1) and their variances (ksd2)
		ksd1 = findexternal("ksd1")
		F = (*ksd1)'
		ksd2 = findexternal("ksd2")
		VF = *ksd2

		LogShow("Returning results to Stata")
		if ((fe1 = findexternal("fe1")) == NULL) fe1 = crexternal("fe1")
		*fe1 = F
		if ((fe2 = findexternal("fe2")) == NULL) fe2 = crexternal("fe2")
		*fe2 = VF
		if ((fe3 = findexternal("fe3")) == NULL) fe3 = crexternal("fe3")
		*fe3 = A
		if ((fe4 = findexternal("fe4")) == NULL) fe4 = crexternal("fe4")
		*fe4 = C
		if ((fe5 = findexternal("fe5")) == NULL) fe5 = crexternal("fe5")
		*fe5 = Q
		if ((fe6 = findexternal("fe6")) == NULL) fe6 = crexternal("fe6")
		*fe6 = R
		if ((fe7 = findexternal("fe7")) == NULL) fe7 = crexternal("fe7")
		*fe7 = initx
		if ((fe8 = findexternal("fe8")) == NULL) fe8 = crexternal("fe8")
		*fe8 = initV
		if ((fe9 = findexternal("fe9")) == NULL) fe9 = crexternal("fe9")
		*fe9 = ss
		if ((fe10 = findexternal("fe10")) == NULL) fe10 = crexternal("fe10")
		*fe10 = MM
		LogShow("Factor extraction completed")
		LogShow("Return to Stata")
	}

	// ricSW(x, q, r, p)
	//
	// Computes starting values for the state-space model from x:
	//   - the first r eigenvectors of variance(x) are used as loadings and
	//     F = x*v as the factors;
	//   - a least-squares regression of F on its p lags fills the top r rows
	//     of the transition matrix A and gives the innovation covariance;
	//   - initx / initV are the initial state mean and variance.
	// Results are stored in the externals ricsw1..ricsw8:
	//   ricsw1 A, ricsw2 Re(C), ricsw3 Q, ricsw4 R, ricsw5 initx,
	//   ricsw6 initV, ricsw7 Mx (column means), ricsw8 Wx (diagonal matrix
	//   of standard deviations).
	// NOTE: x is assigned to below, so the argument is overwritten with its
	// centred and scaled version.
	function ricSW(|x,q,r,p)
	{
		LogShow("Start ricSW")
		// column means and a diagonal matrix of column standard deviations
		Mx = mean(x)
		Wx = diag(diagonal(variance(x):^0.5)')
		TTT = rows(x)
		nnn = cols(x)
		// centre every column, then divide by its standard deviation
		xc = x - J(TTT,1,1)*(colsum(x)/TTT)
		x = xc*invsym(Wx) 
		
		T = rows(x)
		N = cols(x)

		// A is (r*p) x (r*p): its top r rows are filled with the regression
		// coefficients further down; for p>1 the remaining rows are the first
		// r*p-r rows of an identity matrix.
		// Q is zero apart from its top-left r x r block.
		nlag = p-1	
		A_temp = J(r,r*p,0)'
		// NOTE(review): the local variable I has the same name as the built-in
		// I(); the later I(r) and I(raa) are written as function calls — confirm.
		I = I(r*p,r*p)
		eerr = r*p-r
		eeee = r*p
		if (p==1) {
			A = A_temp
		}
		else {
			A = (A_temp'\I[(1..eerr),(1..eeee)])
		}	
		Q = J(r*p,r*p,0)
		Q[(1..r),(1..r)] = I(r)

		// eigen-decomposition of the covariance of the standardized data;
		// only the first r eigenvectors are kept.
		// NOTE(review): this presumes eigensystem() returns the leading
		// eigenvalues first — confirm. d is not used after this point.
		v=.
		d=.
		eigensystem(variance(x), v, d)
		d=diag(d)
		v=v[.,(1..r)]
		d=d[.,(1..r)]
		
		// factors, and the diagonal covariance of what the r components leave
		// unexplained
		F = x*v
		R = diag(diagonal(variance(Re(x-x*v*v'))))
		
		LogShow("Start initial value calculation")
		if (p>0) {
			LogShow("Now that p>0... Calculate initial values")
			// Z stacks the p lags of F side by side (lag 1 first); its rows
			// line up with rows p+1..szr of F
			z = F
			Z = J(0,0,.)
			szr=rows(z)
			szc=cols(z)
			LogShow("Starting for loop over all values of p")
			for (kk=1; kk<=p; kk++) {
				idr1=p-kk+1
				idr2=szr-kk
				adrow=idr2-idr1+1
				if (kk==1) {
					// first lag: grow the empty Z to adrow x szc before filling it
					szzr=rows(Z)
					Z=(Z, J(szzr,szc,.))
					szzc=cols(Z)
					Z=(Z\J(adrow,szzc,.))
					szzr=rows(Z)
				}
				startzr = szzr-adrow+1
				startzc = szzc-szc+1
				if (kk==1) {
					Z[(startzr..szzr), (startzc..szzc)]=Re(z[(idr1..idr2),.])	
				}
				else {
					// further lags are appended as additional columns
					Z = (Z, Re(z[(idr1..idr2),.]) )
				}	
			}
			LogShow("Done looping over values of p")
			// least-squares regression of F_t on its lags
			z = z[(p+1..szr),.]
			A_temp = invsym(Z'*Z)*Z'*z 
			A[(1..r),(1..r*p)] = Re(A_temp')
			// residuals and their covariance
			e = z - Z*A_temp
			H = variance(Re(e)) 
			if (r==q) {
				Q[(1..r),(1..r)] = H
			}
			else { 
				// keep the first q eigenpairs of H and rebuild the top-left
				// block of Q from them; u_orth and e_pc are computed but not
				// used afterwards
				P=.
				M=.
				eigensystem(H, P, M)
				M=diag(M)
				P=P[.,(1..q)]
				M=M[(1..q),(1..q)]
				// flip eigenvector signs so that the first row of P has
				// non-negative real parts
				P = P*diag(sign(Re(P[1,.])))
				u_orth = e*P*(M:^(-0.5))
				e_pc = e*P*P'
				Q[(1..r),(1..r)] = Re(P*M*P')
			}
		}
		LogShow("Second stage")
		if (Re(p)>0) {
			// rebuild Z from F and its first nlag lags (kk=0 is F itself);
			// its first row becomes the initial state
			z = F
			Z = J(0,0,.)
			szr=rows(z)
			szc=cols(z)
			for (kk=0; kk<=nlag; kk++) {
				idr1=nlag-kk+1
				idr2=szr-kk
				adrow=idr2-idr1+1
				if (kk==0) {
					szzr=rows(Z)
					Z=(Z, J(szzr,szc,.))
					szzc=cols(Z)
					Z=(Z\J(adrow,szzc,.))
					szzr=rows(Z)
				}
				startzr = szzr-adrow+1
				startzc = szzc-szc+1
				if (kk==0) {
					Z[(startzr..szzr), (startzc..szzc)]=Re(z[(idr1..idr2),.])	
				}
				else {
					Z = (Z, Re(z[(idr1..idr2),.]) )
				}	
			}
			initx = Z[1,.]'
			
			// initial variance: pinv(I - A#A)*vec(Q), reshaped into r*p rows
			raa=rows(A#A)
			initV = rowshape((pinv(I(raa)-(A#A))*vec(Q))',r*p)
		}
		else {
			initx = J(0,0,.)
			initV = J(0,0,.)
		}
		LogShow("Pareparing results to return to parent process")
		// loadings: the r eigenvectors followed by zero columns for the lags
		C = (v, J(N,r*(nlag),0))	
		// hand the results to the caller through Mata externals
		if ((ricsw1 = findexternal("ricsw1")) == NULL) ricsw1 = crexternal("ricsw1")
		*ricsw1 = A
		if ((ricsw2 = findexternal("ricsw2")) == NULL) ricsw2 = crexternal("ricsw2")
		*ricsw2 = Re(C)
		if ((ricsw3 = findexternal("ricsw3")) == NULL) ricsw3 = crexternal("ricsw3")
		*ricsw3 = Q
		if ((ricsw4 = findexternal("ricsw4")) == NULL) ricsw4 = crexternal("ricsw4")
		*ricsw4 = R
		if ((ricsw5 = findexternal("ricsw5")) == NULL) ricsw5 = crexternal("ricsw5")
		*ricsw5 = initx
		if ((ricsw6 = findexternal("ricsw6")) == NULL) ricsw6 = crexternal("ricsw6")
		*ricsw6 = initV
		if ((ricsw7 = findexternal("ricsw7")) == NULL) ricsw7 = crexternal("ricsw7")
		*ricsw7 = Mx
		if ((ricsw8 = findexternal("ricsw8")) == NULL) ricsw8 = crexternal("ricsw8")
		*ricsw8 = Wx
		LogShow("ricSW completed")
	}

	// kalman_smoother_diag(y, A, C, Q, R, init_x, init_V, model)
	//
	// Runs kalman_filter_diag() forwards and then smooth_update() backwards
	// from period T-1 to period 1.
	//
	// Arguments
	//   y             observations, one column per period
	//   A, C, Q, R    asarrays of system matrices, keyed by model index
	//   init_x/init_V initial state mean and variance
	//   model         model[t] is the key of the system matrices for period t
	//
	// Results are stored in the externals
	//   ksd1 xsmooth  (state dim x T smoothed states)
	//   ksd2 Vsmooth  (asarray: smoothed state variance per period)
	//   ksd3 VVsmooth (asarray: lag-one smoothed covariances, zero for t=1)
	//   ksd4 loglik   (log-likelihood reported by the filter)
	function kalman_smoother_diag(|y, A, C, Q, R, init_x, init_V, model)
	{
		LogShow("Kalman smoother diag start")
		T = cols(y)
		ss = rows(asarray(A,1))

		// no exogenous input is supported here: pass explicit empty
		// placeholders instead of relying on unset variables
		u = J(0,0,.)
		B = J(0,0,.)

		xsmooth = J(ss,T,0)
		Vsmooth = asarray_create("real", 1)
		VVsmooth = asarray_create("real", 1)
		
		LogShow("Call Kalman filter diag")
	
		kalman_filter_diag(y, A, C, Q, R, init_x, init_V, model, u, B)
		// the filter hands its results back through the externals kfd1..kfd4
		kfd1 = findexternal("kfd1")
		xfilt = *kfd1
		kfd2 = findexternal("kfd2")
		Vfilt = *kfd2
		kfd3 = findexternal("kfd3")
		VVfilt = *kfd3
		kfd4 = findexternal("kfd4")
		loglik = *kfd4
		LogShow("Done with Kalman filter diag")
		
		// at the last period the smoothed and filtered quantities coincide
		xsmooth[.,T] = xfilt[.,T]
		asarray(Vsmooth,T,asarray(Vfilt,T))
		
		LogShow("Looping over all time periods")
		// backward pass; the counter is deliberately NOT called ss, so that
		// the state dimension is still intact when it is needed after the loop
		for (t=T-1; t>=1; t--) {
			m = model[t+1]
			smooth_update(xsmooth[.,t+1], asarray(Vsmooth,t+1), xfilt[.,t], asarray(Vfilt,t), asarray(Vfilt,t+1), asarray(VVfilt,t+1), asarray(A,m), asarray(Q,m), J(0,0,.), J(0,0,.))
			su1 = findexternal("su1")
			su2 = findexternal("su2")
			su3 = findexternal("su3")
			xsmooth[.,t]=*su1
			asarray(Vsmooth,t,*su2)
			asarray(VVsmooth,t+1,*su3)
		}
		LogShow("Done with looping over all time periods")
		
		// there is no lag-one covariance for the first period
		asarray(VVsmooth,1,J(ss,ss,0))
		
		if ((ksd1 = findexternal("ksd1")) == NULL) ksd1 = crexternal("ksd1")
		*ksd1 = xsmooth
		if ((ksd2 = findexternal("ksd2")) == NULL) ksd2 = crexternal("ksd2")
		*ksd2 = Vsmooth
		if ((ksd3 = findexternal("ksd3")) == NULL) ksd3 = crexternal("ksd3")
		*ksd3 = VVsmooth
		if ((ksd4 = findexternal("ksd4")) == NULL) ksd4 = crexternal("ksd4")
		*ksd4 = loglik
		LogShow("Kalman smoother diag complete")
	}

	// kalman_filter_diag(y, A, C, Q, R, init_x, init_V, model [, u, B, ndx])
	//
	// Forward filtering pass: calls kalman_update_diag() once per period and
	// accumulates the log-likelihood.
	//
	// Arguments
	//   y             observations, one column per period
	//   A, C, Q, R    asarrays of system matrices, keyed by model index
	//   init_x/init_V state mean and variance handed to the first update
	//   model         model[t] is the key of the system matrices for period t
	//   u, B          optional input (one column of u per period) and an
	//                 asarray of input matrices; empty means no input
	//   ndx           optional; when given together with u, ndx[t] selects the
	//                 state element(s) updated in period t while the others
	//                 are carried over.  An omitted ndx is treated as empty;
	//                 it used to be overwritten even when supplied, which made
	//                 this mode unreachable.
	//                 NOTE(review): ndx[t] is used directly as a subscript, so
	//                 it selects a single index per period — confirm this is
	//                 the intended representation.
	//
	// Results are stored in the externals
	//   kfd1 x (state dim x T filtered states), kfd2 V (asarray of filtered
	//   variances), kfd3 VV (asarray of lag-one covariances), kfd4 loglik.
	function kalman_filter_diag(|y, A, C, Q, R, init_x, init_V, model, u, B, ndx)
	{
		LogShow("Start Kalman filter diag")
		T=cols(y)
		ss=rows(asarray(A,1))

		// defaults apply only to arguments the caller did not supply
		if (args() < 9) u = J(0,0,.)
		if (args() < 10) B = J(0,0,.)
		if (args() < 11) ndx = J(0,0,.)
		no_input = (rows(u)==0|cols(u)==0)
		no_ndx = (rows(ndx)==0|cols(ndx)==0)
		partial = (!no_input & !no_ndx)

		x = J(ss, T, 0)
		V = asarray_create("real", 1)
		VV = asarray_create("real", 1)

		loglik = 0
		for (t=1; t<=T; t++) {
			m = model[t]
			if (t==1) {
				prevx = init_x
				prevV = init_V
				initial = 1
			}
			else {
				prevx = x[.,t-1]
				prevV = asarray(V, t-1)
				initial = 0
			}

			if (no_input) {
				kalman_update_diag(asarray(A,m), asarray(C,m), asarray(Q,m), asarray(R,m), y[.,t], prevx, prevV, initial)
			}
			else if (!partial) {
				kalman_update_diag(asarray(A,m), asarray(C,m), asarray(Q,m), asarray(R,m), y[.,t], prevx, prevV, initial, u[.,t], asarray(B,m))
			}
			else {
				// carry every state over, then update only the selected block
				i = ndx[t]
				x[.,t] = prevx
				prevP = invsym(prevV)
				prevPsmall = prevP[i,i]
				prevVsmall = invsym(prevPsmall)
				kalman_update_diag(asarray(A,m)[i,i], asarray(C,m)[.,i], asarray(Q,m)[i,i], asarray(R,m), y[.,t], prevx[i], prevVsmall, initial, u[.,t], asarray(B,m)[i,.])
			}

			// the update hands its results back through kud1..kud4
			kud1 = findexternal("kud1")
			kud2 = findexternal("kud2")
			kud3 = findexternal("kud3")
			kud4 = findexternal("kud4")
			LL = *kud3

			if (!partial) {
				x[.,t]=*kud1
				asarray(V,t,*kud2)
				asarray(VV,t,*kud4)
			}
			else {
				x[i,t]=*kud1
				smallV=*kud2
				// VV has no entry for period t yet: start from zeros instead
				// of reading a key that was never stored
				tmptmp=J(ss,ss,0)
				tmptmp[i,i]=*kud4
				asarray(VV,t,tmptmp)
				smallP = invsym(smallV)
				prevP[i,i] = smallP
				asarray(V,t,invsym(prevP))
			}
			loglik = loglik + LL
		}
		
		if ((kfd1 = findexternal("kfd1")) == NULL) kfd1 = crexternal("kfd1")
		*kfd1 = x
		if ((kfd2 = findexternal("kfd2")) == NULL) kfd2 = crexternal("kfd2")
		*kfd2 = V
		if ((kfd3 = findexternal("kfd3")) == NULL) kfd3 = crexternal("kfd3")
		*kfd3 = VV
		if ((kfd4 = findexternal("kfd4")) == NULL) kfd4 = crexternal("kfd4")
		*kfd4 = loglik
	}

	// kalman_update_diag(A, C, Q, R, y, x, V [, initial, u, B])
	//
	// One Kalman filter step for a measurement covariance R of which only
	// the diagonal is used.
	//
	// Arguments
	//   A, C, Q, R  system matrices for this period
	//   y           observation vector
	//   x, V        previous state mean and variance
	//   initial     optional, default 0.  When 1, x and V are taken as the
	//               prediction for this period (no transition is applied);
	//               otherwise the prediction is A*x and A*V*A' + Q.
	//   u, B        optional input vector and input matrix; when u is
	//               non-empty, B*u is added to the predicted state.
	//
	// The three optional arguments used to be overwritten unconditionally,
	// so `initial' passed by the filter for the first period was ignored and
	// the input branch could never run.  They are now defaulted only when
	// the caller omits them.
	//
	// Results are stored in the externals
	//   kud1 xnew, kud2 Vnew, kud3 loglik, kud4 VVnew = (I - K*C)*A*V.
	function kalman_update_diag(|A, C, Q, R, y, x, V, initial, u, B) 
	{
		if (args() < 8) initial = 0
		if (args() < 9) u = J(0,0,.)
		if (args() < 10) B = J(0,0,.)
		has_input = !(rows(u)==0|cols(u)==0)

		// prediction step
		if (initial==1) {
			xpred = x
			if (has_input) xpred = x + B*u
			Vpred = V
		}
		else {
			xpred = A*x
			if (has_input) xpred = A*x + B*u
			Vpred = A*V*A' + Q
		}

		// innovation
		e = y - C*xpred
		ss = max((rows(A),cols(A)))
		d = rows(e)

		// Inverse of S = C*Vpred*C' + R obtained from the diagonal of R and a
		// state-sized inverse only; the N x N matrix S itself is never formed.
		Rd = diagonal(R)
		Rinv = diag(1:/Rd)
		GG = C'*Rinv*C
		W = I(ss)+Vpred*GG
		Sinv = Rinv - Rinv*C*pinv(W)*Vpred*C'*Rinv

		// det(S) = prod(diagonal(R)) * det(I + Vpred*GG)
		proddr = Rd[1]
		for (nnid=2; nnid<=length(Rd); nnid++) {
			proddr = proddr*Rd[nnid]
		}
		detS = proddr*det(W)
		denom = (2*3.141592653589793)^(d/2)*sqrt(abs(detS))
		mahal = e'*Sinv*e
		loglik = -0.5*mahal - log(denom)

		// gain and updated moments
		K = Vpred*C'*Sinv 
		xnew = xpred + K*e               
		Vnew = (I(ss) - K*C)*Vpred    
		VVnew = (I(ss) - K*C)*A*V

		if ((kud1 = findexternal("kud1")) == NULL) kud1 = crexternal("kud1")
		*kud1 = xnew
		if ((kud2 = findexternal("kud2")) == NULL) kud2 = crexternal("kud2")
		*kud2 = Vnew
		if ((kud3 = findexternal("kud3")) == NULL) kud3 = crexternal("kud3")
		*kud3 = loglik
		if ((kud4 = findexternal("kud4")) == NULL) kud4 = crexternal("kud4")
		*kud4 = VVnew
	}	

	// smooth_update(xsmooth_future, Vsmooth_future, xfilt, Vfilt,
	//               Vfilt_future, VVfilt_future, A, Q, B, u)
	//
	// One backward smoothing step: combines the filtered moments of period t
	// with the smoothed moments of period t+1.
	//   B, u   input matrix and input vector; an empty B means no input
	// Results are stored in the externals
	//   su1 smoothed state mean for period t
	//   su2 smoothed state variance for period t
	//   su3 smoothed lag-one covariance attached to period t+1
	function smooth_update(|xsmooth_future, Vsmooth_future, xfilt, Vfilt,  Vfilt_future, VVfilt_future, A, Q, B, u)
	{
		// one-step-ahead prediction implied by the filtered moments
		with_input = !(rows(B)==0|cols(B)==0)
		if (with_input) {
			x_ahead = A*xfilt + B*u
		}
		else {
			x_ahead = A*xfilt
		}
		V_ahead = A*Vfilt*A' + Q

		// smoother gain
		gain = Vfilt*A'*pinv(V_ahead)

		x_sm = xfilt + gain*(xsmooth_future - x_ahead)
		V_sm = Vfilt + gain*(Vsmooth_future - V_ahead)*gain'
		VV_sm = VVfilt_future + (Vsmooth_future - Vfilt_future)*pinv(Vfilt_future)*VVfilt_future

		// publish the three results, creating the externals on first use
		out = findexternal("su1")
		if (out == NULL) out = crexternal("su1")
		*out = x_sm

		out = findexternal("su2")
		if (out == NULL) out = crexternal("su2")
		*out = V_sm

		out = findexternal("su3")
		if (out == NULL) out = crexternal("su3")
		*out = VV_sm
	}

end

