function EMstep = def_EMqn(Y,Xf,Xr,mfo,R)
%DEF_EMQN Build one EM iteration for a logit model with normal random coefficients.
%   EMstep = DEF_EMQN(Y,Xf,Xr,mfo,R) returns a function handle
%   [est,ll,numinner] = EMstep(est) that performs one EM update.
%
%   Inputs
%     Y    - (J*T*N) x 1 choice indicator; entries equal to 1 mark the chosen
%            alternative. Assumes exactly one chosen alternative per (t,n)
%            choice situation, i.e. T*N ones in total.
%     Xf   - (J*T*N) x K regressors whose coefficients est.FC are common to
%            all individuals.
%     Xr   - (J*T*N) x D regressors with random coefficients. Rows come in
%            one J*T block per individual n, alternatives varying fastest.
%     mfo  - struct with fields N (individuals, each with its own R
%            coefficient draws), T (choice situations per individual) and
%            J (alternatives per situation; the logit is taken over J).
%     R    - number of simulation draws per individual.
%
%   The est struct passed to the handle must contain
%     FC        - K x 1 fixed coefficients. If empty, the fminunc update is
%                 skipped (Xf*FC must still be valid, e.g. Xf of size
%                 J*T*N x 0 with FC = zeros(0,1)).
%     RCmean    - D x 1 mean of the random coefficients.
%     cholRCvar - D x D lower Cholesky factor of their covariance.
%   The handle returns est with FC, RCmean, RCvar and cholRCvar updated;
%   ll = simulated log-likelihood evaluated at the INPUT parameters; and
%   numinner = number of fminunc function evaluations (0 if FC is empty).
%
%   NOTE: the standard-normal draws are generated once here after rng(10),
%   which reseeds the global random number generator as a side effect.

N = mfo.N;
T = mfo.T;
J = mfo.J;

% Fixed draws, reused by every EM step so the simulated likelihood is a
% deterministic function of the parameters. Column (n-1)*R+r belongs to
% draw r of individual n.
rng(10);
dr_sn = randn(size(Xr,2),R*N);

o1 = optimoptions(@fminunc,'Display','off','Algorithm','Quasi-newton','GradObj','on','DerivativeCheck','off');

function logpr = logchoiceprob(vFC,vRC)
    % Log logit probabilities for utilities vFC + vRC, returned as
    % (J*T*N) x R. The maximum utility of each choice situation is
    % subtracted before exp so large utilities cannot overflow to Inf,
    % and log-probabilities stay finite even when probabilities underflow.
    v = reshape(bsxfun(@plus,vFC,vRC),[J T*N*R]);
    v = bsxfun(@minus,v,max(v,[],1));
    logpr = reshape(bsxfun(@minus,v,log(sum(exp(v),1))),[J*T*N R]);
end

function [ll,g] = calclike(parms,vRC,qi)
    % Negative qi-weighted log-likelihood as a function of the fixed
    % coefficients parms (draw r of individual n weighted by qi(n,r)),
    % together with its gradient with respect to parms.
    logpr = reshape(logchoiceprob(Xf*parms,vRC),[J*T N R]);
    wq = reshape(qi,[1 N R]);

    Elogpr = sum(bsxfun(@times,logpr,wq),3);
    ll = -Y'*Elogpr(:);

    % Valid because each row of qi sums to 1 and each choice situation has
    % exactly one chosen alternative: gradient = -Xf'*(Y - E_q[pr]).
    Epr = sum(bsxfun(@times,exp(logpr),wq),3);
    g = -Xf'*(Y - Epr(:));
end

EMstep = @calc;

function [est,ll,numinner] = calc(est)

    vFC = Xf*est.FC;

    % Random coefficients for every (individual, draw): D x (R*N).
    beta_i = bsxfun(@plus,est.RCmean,est.cholRCvar*dr_sn);

    % Random-coefficient utility: each individual's rows use only that
    % individual's R draws.
    vRC = zeros([J*T*N R]);
    for n = 1:N
        rows = (n-1)*J*T+1:n*J*T;
        vRC(rows,:) = Xr(rows,:)*beta_i(:,(n-1)*R+1:n*R);
    end

    % E-step in log space: summing log-probabilities over T avoids the
    % underflow of a product of T probabilities.
    logpr = logchoiceprob(vFC,vRC);
    loglike = reshape(sum(reshape(logpr(Y==1,:),[T N R]),1),[N R]);

    % log(mean(exp(loglike),2)) via log-sum-exp; qi are the normalised
    % posterior weights over draws (each row sums to 1).
    lmax = max(loglike,[],2);
    w = exp(bsxfun(@minus,loglike,lmax));
    sw = sum(w,2);
    ll = sum(lmax + log(sw) - log(R),1);
    qi = bsxfun(@rdivide,w,sw);

    % M-step for the fixed coefficients.
    if ~isempty(est.FC)
        objective = @(parms) calclike(parms,vRC,qi);
        [est.FC,~,~,mleoutput] = fminunc(objective,est.FC,o1);
        numinner = mleoutput.funcCount;
    else
        numinner = 0;
    end

    % M-step for the random-coefficient distribution: weighted moments of
    % the draws. qi' flattened matches the column order of beta_i, and the
    % weights of each individual sum to 1, so dividing by N normalises.
    qi = qi';
    wt_beta = bsxfun(@times,qi(:)',beta_i);
    est.RCmean = sum(wt_beta,2)./N;
    est.RCvar = (wt_beta*beta_i')./N - est.RCmean*est.RCmean';
    % Remove floating-point asymmetry before storing and factorising.
    est.RCvar = (est.RCvar + est.RCvar')/2;
    est.cholRCvar = chol(est.RCvar,'lower');

end

end
