function EMstep = def_EMnr(Y,Xf,Xr,mfo,R,seed)
%DEF_EMNR Build one EM iteration for a random-coefficient (mixed) logit model.
%   EMSTEP = DEF_EMNR(Y,XF,XR,MFO,R) returns a function handle such that
%   [EST,LL,NUMINNER] = EMSTEP(EST) performs one EM update of the parameter
%   struct EST using R simulation draws per individual.
%   EMSTEP = DEF_EMNR(Y,XF,XR,MFO,R,SEED) seeds the draws with SEED
%   (default 10, which reproduces the previous fixed seed).
%
%   Inputs
%     Y    - vector of length J*T*N; Y==1 marks the chosen alternative.
%            Exactly one alternative must be marked per (t,n) situation,
%            otherwise the reshape to [T N R] below fails.
%     Xf   - (J*T*N) x Kf design matrix for the fixed coefficients est.FC.
%     Xr   - (J*T*N) x Kr design matrix for the random coefficients.
%     mfo  - struct with fields N (individuals), T (situations per
%            individual) and J (alternatives per situation).
%     R    - number of simulation draws per individual.
%     seed - optional seed passed to rng before drawing (default 10).
%
%   Rows of Y, Xf and Xr are ordered with the alternative index varying
%   fastest, then the situation t, then the individual n.
%
%   EST fields read/written: FC (fixed coefficients; may be empty, in which
%   case the fixed-coefficient update is skipped), RCmean (Kr x 1 mean of
%   the random coefficients), cholRCvar (lower Cholesky factor of RCvar)
%   and RCvar (Kr x Kr covariance of the random coefficients).
%
%   Outputs of EMSTEP
%     est      - updated parameters.
%     ll       - simulated log-likelihood sum_n log(mean_r L(n,r)),
%                evaluated at the INPUT parameters (before the update).
%     numinner - number of Newton iterations used for est.FC (0 if empty).
%
%   Side effect: the standard-normal draws are generated once, here, after
%   rng(seed), which resets the global random stream. Every call of EMSTEP
%   reuses the same draws.

N = mfo.N;
T = mfo.T;
J = mfo.J;

if nargin < 6 || isempty(seed)
    seed = 10;
end

% Newton settings for the fixed-coefficient M-step.
maxNewton = 100;
newtonTol = 1e-6;

% Kr x (R*N) draws; column (n-1)*R+r is draw r of individual n.
rng(seed);
dr_sn = randn(size(Xr,2),R*N);

EMstep = @calc;

    function logp = logchoiceprob(util)
        % Log logit probabilities for a (J*T*N) x R utility matrix,
        % normalizing over the J alternatives of each (t,n,r) column.
        % Subtracting the column maximum keeps exp() from overflowing.
        util = reshape(util,J,[]);
        util = bsxfun(@minus,util,max(util,[],1));
        logp = bsxfun(@minus,util,log(sum(exp(util),1)));
        logp = reshape(logp,[J*T*N R]);
    end

    function [g,H] = calcderiv(pr,qi)
        % Gradient g and Hessian H, w.r.t. the fixed coefficients, of the
        % qi-weighted log-likelihood.
        %   pr - (J*T) x N x R choice probabilities per draw.
        %   qi - N x R posterior weights of the draws.

        % Posterior-weighted probabilities, (J*T) x N.
        Epr = sum(bsxfun(@times,pr,reshape(qi,[1 N R])),3);
        g = Xf'*(Y - Epr(:));

        % Epr2(j,k,t,n) = sum_r qi(n,r) * p_j * p_k.
        Epr2 = bsxfun(@times,reshape(pr,[1 J T N R]),reshape(qi,[1 1 1 N R]));
        Epr2 = sum(bsxfun(@times,reshape(pr,[J 1 T N R]),Epr2),5);
        % Sum over (t,n) of X_tn' * Epr2(:,:,t,n) * X_tn.
        XEpr2 = sum(bsxfun(@times,reshape(Xf',[],J,1,T,N),reshape(Epr2,[1 J J T N])),2);
        Xpr2X = reshape(XEpr2,[],J*T*N)*Xf;
        X_pr = bsxfun(@times,Xf,Epr(:));
        H = -(Xf'*X_pr - Xpr2X);
    end

    function [est,ll,numinner] = calc(est)
        % Fixed-coefficient utility; zero when there are no fixed
        % coefficients (avoids a J*T*N x 0 result from Xf*[]).
        if isempty(est.FC)
            vFC = zeros(J*T*N,1);
        else
            vFC = Xf*est.FC;
        end

        % Random coefficients for every (individual, draw) pair.
        beta_i = bsxfun(@plus,est.RCmean,est.cholRCvar*dr_sn);

        % Random-coefficient utility: individual n's rows use only that
        % individual's R draws.
        vRC = zeros([J*T*N R]);
        for n = 1:N
            rows = (n-1)*J*T+1:n*J*T;
            vRC(rows,:) = Xr(rows,:)*beta_i(:,(n-1)*R+1:n*R);
        end

        % E-step, carried out in log space so that long choice sequences
        % cannot underflow the per-individual likelihood.
        logpr = logchoiceprob(bsxfun(@plus,vFC,vRC));
        % loglike(n,r): log-probability of individual n's chosen sequence
        % under draw r.
        loglike = reshape(sum(reshape(logpr(Y==1,:),[T N R]),1),[N R]);
        lmax = max(loglike,[],2);
        scaled = exp(bsxfun(@minus,loglike,lmax));   % largest entry per row is 1
        scaledsum = sum(scaled,2);
        % sum_n log(mean_r exp(loglike(n,r))) via log-sum-exp.
        ll = sum(lmax + log(scaledsum./R),1);
        % Posterior weight of each draw per individual; rows sum to 1.
        qi = bsxfun(@rdivide,scaled,scaledsum);

        % M-step for the fixed coefficients: Newton iterations on the
        % qi-weighted log-likelihood with the draws and qi held fixed.
        numinner = 0;
        if ~isempty(est.FC)
            pr = reshape(exp(logpr),[J*T N R]);
            for iter = 1:maxNewton
                prevFC = est.FC;
                [g,H] = calcderiv(pr,qi);
                est.FC = est.FC - (H \ g);
                numinner = numinner + 1;
                if norm(prevFC-est.FC) < newtonTol, break; end
                vFC = Xf*est.FC;
                pr = reshape(exp(logchoiceprob(bsxfun(@plus,vFC,vRC))),[J*T N R]);
            end
        end

        % M-step for the random-coefficient distribution: qi-weighted mean
        % and covariance of the draws. Each individual's weights sum to 1,
        % so the total weight is N. Centering before the outer product
        % avoids the cancellation of E[bb'] - mm'.
        w = reshape(qi',1,[]);   % matches column order (n-1)*R+r of beta_i
        est.RCmean = sum(bsxfun(@times,w,beta_i),2)./N;
        dev = bsxfun(@minus,beta_i,est.RCmean);
        RCvar = (bsxfun(@times,w,dev)*dev')./N;
        est.RCvar = (RCvar + RCvar')./2;   % remove round-off asymmetry
        % chol errors if RCvar is not positive definite (e.g. degenerate draws).
        est.cholRCvar = chol(est.RCvar,'lower');
    end

end
