% AUTHOR : DAVID KESSLER, dkessler@live.unc.edu

% Load up a bunch of .mat files and tabulate their results
function magg = dirmspAgg()
% DIRMSPAGG  Load a set of .mat result files and tabulate their results.
%
%   magg = dirmspAgg()
%
% Inputs come from the environment:
%   * getenvargs() is asked for HOSTNAME, LSB_JOBID, SCRIPTNAME, DATADIR,
%     ALIST, LOGTAG and LOGTHRESH.
%   * OUTPNAME (and TMP as a fallback) are read with getenv() to choose the
%     output base name.
%
% Each file named by ALIST must contain the variables 'marg' and 'mres'.
% One row of the results table 'restbl' is filled per file; the column
% order is given by 'reslbls'.  The table, the file list, the labels,
% 'compargs' and one '<LABEL>idx' column-index variable per label are
% saved to '<base>.mat'.  The environment variable OUTBNAME is set to the
% base name.
%
% Returns:
%   magg - struct with the single field 'success'.
%          NOTE(review): 'success' is initialised to 0 and never updated
%          in this function -- confirm whether callers expect 1 on
%          successful completion.

LogIt(sprintf('START\n'), 1, 1);
magg.success = 0;

margstrs = {'HOSTNAME', 'LSB_JOBID', 'SCRIPTNAME', ...
            'DATADIR', 'ALIST', ...
            'LOGTAG', ...
            'LOGTHRESH'};

compargs.greeting = 'Hallo, Seemann!';
compargs = getenvargs(compargs, margstrs);

LogIt(dumpstruct(compargs, 0, 3), 1, compargs.LOGTHRESH);

% NOTE(review): margstrs requests 'ALIST' but the text is read from the
% field 'ALISTArg'; presumably getenvargs() provides that field -- confirm.
alistArg = compargs.ALISTArg;

% SECURITY NOTE: in both branches below the ALIST text (taken from the
% environment) is turned into MATLAB source and passed to eval().  This is
% kept for compatibility with existing job scripts, but ALIST must only
% ever come from a trusted source.

% Allow the entries of ALIST to be subdirectories with wildcards in the
% file names.  Entries are delimited by '+', which is turned into a single
% quote so that the list can be eval'ed as a cell array of strings.

if any(alistArg(:) == '*')

    litList = alistArg;
    litList(litList==43) = 39;      % 43 is '+', 39 is the single quote
    litCell = cell(1,1);
    evstr = strcat('litCell = {', litList, '};');
    eval(evstr);

    mlist = [];

    for lidx = 1:numel(litCell)

        litArg = litCell{lidx};

        LogIt(sprintf('Examining %s\n', litArg), 1, 1);

        % Look for DIRECTORY stuff: assume always provided as UNIX "/"
        seps = strfind(litArg, '/');

        if ~isempty(seps)
            msubdir = strcat(litArg(1:(seps(end)-1)), filesep);
        else
            msubdir = '';
        end

        % Trouble on UNIX with ls() interpretation, so use dir()
        if ~isempty(compargs.DATADIR)
            mlistA = dir(strcat(compargs.DATADIR, filesep, litArg));
        else
            mlistA = dir(litArg);
        end

        % dir() returns bare file names; put the directory part back on
        for midx = 1:numel(mlistA)
            if ~isempty(compargs.DATADIR)
                mlistA(midx).name = strcat(compargs.DATADIR, filesep, msubdir, mlistA(midx).name);
            else
                mlistA(midx).name = strcat(msubdir, mlistA(midx).name);
            end
        end

        if isempty(mlist)
            mlist = mlistA;
        else
            mlist = [mlist ; mlistA];
        end
    end

    alist = cell(numel(mlist),1);
    for aidx = 1:length(alist)
        alist{aidx} = mlist(aidx).name;
    end

else
    % Assume it is a LIST
    % HACK - replaces all '+' with single quotes, it is a way to get a list of
    % character strings set up so that MATLAB can eval them.
    alistArg(alistArg == 43) = 39;

    evstr = strcat('alist = {', alistArg, '};');
    try
        eval(evstr);
    catch ME
        % Keep the original message prefix but report the underlying cause
        error('Trouble on eval of ALIST: %s', ME.message);
    end

    if ~isempty(compargs.DATADIR)
        for cidx = 1 : numel(alist)
            alist{cidx} = strcat(compargs.DATADIR, filesep, alist{cidx});
        end
    end
end

% Column labels of the results table, in column order.
reslbls = {'MAXITER', ...
           'NTRAIN', 'SEED', 'NSEED', ...
           'P1ALPHA', ...
           'P2ALPHA', ...
           'P3MARGALPHA', 'P3BASEALPHA', ...
           'P3ACCRATE', ...
           'P1MARGKL', 'P2MARGKL', 'P3MARGKL', 'P4MARGKL', ...
           'P1MARGKL1', 'P1MARGKL2', 'P1MARGKL3', 'P1MARGKL4', 'P1MARGKL5', ...
           'P2MARGKL1', 'P2MARGKL2', 'P2MARGKL3', 'P2MARGKL4', 'P2MARGKL5', ...
           'P3MARGKL1', 'P3MARGKL2', 'P3MARGKL3', 'P3MARGKL4', 'P3MARGKL5', ...
           'P4MARGKL1', 'P4MARGKL2', 'P4MARGKL3', 'P4MARGKL4', 'P4MARGKL5', ...
           'P1LDFMSEG', 'P2LDFMSEG', 'P3LDFMSEG', 'P4LDFMSEG', ...
           'P1LDFMSE',  'P2LDFMSE',  'P3LDFMSE',  'P4LDFMSE', ...
           'P1JOINKL',  'P2JOINKL',  'P3JOINKL',  'P4JOINKL', ...
};

% Create one workspace variable '<LABEL>idx' per label holding its column
% number.  These variables are used below to address restbl and are also
% written to the output file, so they have to exist under these names.
idxvarnames = cell(1, numel(reslbls));
for lidx = 1 : numel(reslbls)
    idxvarnames{lidx} = strcat(reslbls{lidx}, 'idx');
    eval(sprintf('%s = %d;', idxvarnames{lidx}, lidx));
end

restbl = zeros(length(alist), length(reslbls));

for fidx = 1:length(alist)
    afile = alist{fidx};

    LogIt(sprintf('Loading %s\n', afile), 1, 1);

    % Load into a struct rather than straight into the workspace: that way
    % a file missing 'marg'/'mres' cannot silently reuse the values left
    % over from the previous file, and no variable stored in the file can
    % overwrite a local of this function.
    loaded = load(afile);

    if ~isfield(loaded, 'marg')
        error('No MARG after load of %s', afile);
    end

    if ~isfield(loaded, 'mres')
        error('No MRES after load of %s', afile);
    end

    marg = loaded.marg;
    mres = loaded.mres;

    restbl(fidx,      MAXITERidx) = marg.MAXITER;
    restbl(fidx,       NTRAINidx) = sum(mres.trainobs);
    restbl(fidx,        NSEEDidx) = marg.NSEED;
    restbl(fidx,         SEEDidx) = marg.SEED;
    restbl(fidx,      P1ALPHAidx) = marg.P1ALPHA;
    restbl(fidx,  P3MARGALPHAidx) = marg.P3MARGALPHA;
    restbl(fidx,  P3BASEALPHAidx) = marg.P3BASEALPHA;
    restbl(fidx,    P3ACCRATEidx) = mres.p3accrate;

    % New stuff that might not be in all results: default P2ALPHA to 1
    if isfield(marg, 'P2ALPHA')
        locP2ALPHA = marg.P2ALPHA;
    else
        locP2ALPHA = 1;
    end
    restbl(fidx, P2ALPHAidx) = locP2ALPHA;

    % RE-DO ALL COMPUTATIONS according to the design

    % For MARGINS
    marger = getmarger(marg.DVEC);

    % For LDFs
    xres = getmargtables(marg.DVEC, mres.truepi);

    % For P1
    fitmat = iterpropfitPdim(mres.truemarg, marg.DVEC);
    fitvec = fitmat(:,end);

    % Unnormalised posteriors for the four procedures
    post1 = mres.mobs + fitvec * marg.P1ALPHA;
    post2 = mres.mobs + locP2ALPHA;
    post3 = mres.post3;
    post4 = mres.post4;

    [join1kld, marg1kld, marg1kldj, ldf1mseG, ldf1mse] = dirmspAggScore(post1, marg, mres, marger, xres);
    [join2kld, marg2kld, marg2kldj, ldf2mseG, ldf2mse] = dirmspAggScore(post2, marg, mres, marger, xres);
    [join3kld, marg3kld, marg3kldj, ldf3mseG, ldf3mse] = dirmspAggScore(post3, marg, mres, marger, xres);
    [join4kld, marg4kld, marg4kldj, ldf4mseG, ldf4mse] = dirmspAggScore(post4, marg, mres, marger, xres);

    restbl(fidx, [P1MARGKLidx P2MARGKLidx P3MARGKLidx P4MARGKLidx]) = [marg1kld marg2kld marg3kld marg4kld];
    restbl(fidx, [P1LDFMSEGidx P2LDFMSEGidx P3LDFMSEGidx P4LDFMSEGidx]) = [ldf1mseG ldf2mseG ldf3mseG ldf4mseG];
    restbl(fidx, [P1LDFMSEidx  P2LDFMSEidx  P3LDFMSEidx P4LDFMSEidx])  = [ldf1mse  ldf2mse  ldf3mse  ldf4mse];
    restbl(fidx, [P1JOINKLidx P2JOINKLidx P3JOINKLidx P4JOINKLidx]) = [join1kld join2kld join3kld join4kld];

    % The table has exactly five per-margin columns per procedure, so
    % these assignments require numel(marg.DVEC) == 5.
    restbl(fidx, P1MARGKL1idx : P1MARGKL5idx) = marg1kldj;
    restbl(fidx, P2MARGKL1idx : P2MARGKL5idx) = marg2kldj;
    restbl(fidx, P3MARGKL1idx : P3MARGKL5idx) = marg3kldj;
    restbl(fidx, P4MARGKL1idx : P4MARGKL5idx) = marg4kldj;
end

% Get the base name and save the results
fname = getenv('OUTPNAME');
if isempty(fname)
    bname = strcat(getenv('TMP'), filesep, 'dirmspAgg');
elseif numel(fname) >= 4 && strcmpi(fname((end-2):end),'txt')
    % Drop the last four characters (the 'txt' plus the character before
    % it).  The length guard keeps very short names from raising an
    % indexing error or producing an empty base name.
    bname = fname(1:(end-4));
else
    bname = fname;
end

setenv('OUTBNAME', bname);

LogIt(sprintf('Saving results to %s.mat\n', bname), 1, compargs.LOGTHRESH);
save(bname, 'alist', 'reslbls', 'restbl', 'compargs', idxvarnames{:});

end

function [joinkld, margkld, margkldj, ldfmseG, ldfmse] = dirmspAggScore(post, marg, mres, marger, xres)
% DIRMSPAGGSCORE  Score one unnormalised posterior vector against the truth.
%
%   post   - unnormalised posterior; it is divided by its sum here
%   marg   - struct loaded from the result file; only marg.DVEC is used
%   mres   - struct loaded from the result file; truepi and truemarg are used
%   marger - value returned by getmarger(marg.DVEC)
%   xres   - value returned by getmargtables(marg.DVEC, mres.truepi)
%
% Returns joinKLDM() on the joint, margmeasureKLM() over all margins,
% margmeasureKLM() for each margin separately (row vector), msemeasure()
% over the stacked LDFs, and ldfmsemeasure() over the LDFs.

joinp = post / sum(post);
joinkld = joinKLDM(mres.truepi, joinp);

margest = reshape(marger * joinp, max(marg.DVEC), numel(marg.DVEC));
margkld = margmeasureKLM(marg.DVEC, mres.truemarg, margest);

margkldj = zeros(1, numel(marg.DVEC));
for pidx = 1 : numel(marg.DVEC)
    margkldj(pidx) = margmeasureKLM(marg.DVEC(pidx), mres.truemarg(:,pidx), margest(:,pidx));
end

ldf = xres.lsummer * log(xres.msummer * joinp);
ldfmseG = msemeasure(ldf(:), xres.ldfstack(:));
ldfmse = ldfmsemeasure(ldf(:), xres.ldfstack(:), xres.ldffidx);

end

function mkld = margmeasure(DVEC, truemarg, estmarg)
% MARGMEASURE  Mean of the squared per-margin dirkld() values.
%
%   DVEC     - vector; DVEC(k) is the number of leading rows used in column k
%   truemarg - matrix whose column k holds the true margin k
%   estmarg  - matrix whose column k holds the estimated margin k

nmarg = numel(DVEC);
klds = zeros(1, nmarg);

% One dirkld() value per margin, using only the first DVEC(k) rows
for k = 1:nmarg
    rows = 1:DVEC(k);
    klds(k) = dirkld(truemarg(rows, k), estmarg(rows, k));
end

mkld = mean(klds .* klds);

end

function mkld = margmeasurefull(DVEC, nfull, truepi, postvec)
% MARGMEASUREFULL  Mean of the squared per-margin dirkld() values, with
% the margins derived from full vectors via getmarger(DVEC).
%
%   DVEC    - vector; DVEC(k) is the number of leading rows used in column k
%   nfull   - scalar multiplier applied to the true margins
%   truepi  - true vector, multiplied by the getmarger() result
%   postvec - estimated vector, multiplied by the getmarger() result

nmarg = numel(DVEC);
nrows = max(DVEC);
klds = zeros(1, nmarg);

collapse = getmarger(DVEC);

% Lay the margins out as one column per margin
margtrue = nfull * reshape(collapse * truepi, nrows, nmarg);
margest = reshape(collapse * postvec, nrows, nmarg);

% One dirkld() value per margin, using only the first DVEC(k) rows
for k = 1:nmarg
    rows = 1:DVEC(k);
    klds(k) = dirkld(margtrue(rows, k), margest(rows, k));
end

mkld = mean(klds .* klds);

end

function mmse = msemeasure(truequant, estquant)
% MSEMEASURE  Mean of the squared element-wise differences between
% truequant and estquant.

err = truequant - estquant;
sqerr = err .* err;
mmse = mean(sqerr);

end

% ------------------------------------------------------------------------
% Simple Multinomial K-L divergence.  That's all we're measuring.
% Do one for each margin and then get the MSE for the lot
% NOTE: unlike joinKLDM, no guard is applied here for estimates that are
% zero, so a zero estimate with a positive true entry gives Inf.
% ASSUMES all of the true entries are > 0
function mkld = margmeasureKLM(DVEC, truemarg, estmarg)
% MARGMEASUREKLM  Mean of the squared per-margin multinomial K-L
% divergences sum(piT .* (log(piT) - log(piE))).
%
%   DVEC     - vector; DVEC(k) is the number of leading rows used in column k
%   truemarg - matrix whose column k holds the true margin k
%   estmarg  - matrix whose column k holds the estimated margin k

nmarg = numel(DVEC);
klds = zeros(1, nmarg);

% One divergence per margin, using only the first DVEC(k) rows
for k = 1:nmarg
    rows = 1:DVEC(k);
    ptrue = truemarg(rows, k);
    pest = estmarg(rows, k);
    klds(k) = sum( ptrue .* (log(ptrue) - log(pest)) );
end

mkld = mean(klds .* klds);

end

function mse = ldfmsemeasure(ldfest, ldftrue, tableids)
% LDFMSEMEASURE  Average of the per-table MSEs.
%
%   ldfest   - estimated values
%   ldftrue  - true values, same layout as ldfest
%   tableids - table id of each element; elements sharing an id are
%              scored together by msemeasure()

ids = unique(tableids);
ntbl = numel(ids);
pertable = zeros(1, ntbl);

% One MSE per distinct table id
for t = 1:ntbl
    intable = (tableids == ids(t));
    pertable(t) = msemeasure(ldftrue(intable), ldfest(intable));
end

mse = mean(pertable);

end

function mkld = joinKLDM(jointrue, joinest)
% JOINKLDM  Multinomial K-L divergence of joinest from jointrue:
%   sum( jointrue .* (log(jointrue) - log(joinest)) )
%
%   jointrue - true cell probabilities (any shape; used as a vector)
%   joinest  - estimated cell probabilities, same number of elements
%
% Estimates that are <= 0 are replaced by a tiny positive floor so that
% log() stays finite.  The floor is the smaller of realmin and the
% smallest positive estimate; realmin alone is used when no estimate is
% positive.  Cells whose true probability is exactly 0 contribute 0
% (the usual 0*log(0) = 0 convention) instead of NaN.

lclest = joinest;
nonpos = (joinest <= 0);
if any(nonpos(:))
    posvals = joinest(joinest > 0);
    % Prepending realmin keeps min() well-defined when posvals is empty
    lclest(nonpos) = min([realmin; posvals(:)]);
end

ptrue = jointrue(:);
logratio = log(ptrue) - log(lclest(:));
logratio(ptrue == 0) = 0;   % avoid 0 * (-Inf) = NaN

mkld = ptrue.' * logratio;

end
