% ------------------------------------------------------------------------
% dirmsp
% AUTHOR : DAVID KESSLER, dkessler@live.unc.edu
%
% Dirichlet priors, marginally specified priors
%
% 1: Dirichlet distribution with an attempted informative structure
% 2: Standard Dirichlet distribution
% 3: Standard Dirichlet distribution with marginals replaced
%
% 1: one-step, closed form.
% 2: one-step, closed form.
% 3: M-H
%
% MSETUPSCRIPT = getdirmspargs
% MWORKSCRIPT = dirmsp
% Generalized to unequal d_j values
% And to include marginal data sets
% returns the results struct mres and also saves mres and marg to a .MAT file
% ------------------------------------------------------------------------
function mres = dirmsp(marg)
% DIRMSP  Build posteriors under priors 1-4 for one data sample and save them.
%
%   mres = dirmsp(marg)
%
%   Loads a "truth" data set, a case (training) data set and a marginal
%   table, forms four posterior parameter vectors over the prod(marg.DVEC)
%   cells, summarizes their margins/joints, and saves mres and marg to
%   <OUTPNAME base>_<sampidx>.mat.
%
%   Fields of marg read here:
%     logstr, LOGTHRESH  - prefix / threshold passed to LogIt
%     DVEC               - number of levels of each variable
%     TRUEDATAFILE       - evaluated with eval; must define mobs
%     CASEFROMTRUE       - 1: draw NTRAIN cases from the "truth" distribution
%     NTRAIN             - number of cases to draw (overwritten when the
%                          cases are loaded from CASEDATAFILE)
%     CASEDATAFILE       - evaluated with eval; must define mobs
%     MARGFROMTRUE       - 1: take margins from the "truth" distribution
%     MARGDATAFILE       - evaluated with eval; must define truedist
%     P1ALPHA, P2ALPHA, P3MHALPHA, P3BASEALPHA - prior weights
%     P3METHOD           - which M-H routine is used for prior 3
%     sampidx            - appended to the output file name
%
%   Output:
%     mres - results struct: post1..post4 (posterior parameter vectors),
%            marg1..marg4 (normalized margins), join1..join4 (normalized
%            joints), mldf* (getmargtables output), truemarg, truepi, etc.
%            mres.logstr is set to 'FINISHED' after saving, so that field
%            is not present in the .mat file.
%
%   Environment: reads OUTPNAME (must be non-empty); sets OUTBNAME.

LogIt(sprintf('%s START\n', marg.logstr), 1, marg.LOGTHRESH);

% Add some logging info: strip '$Rev: ' and ' $' from the keyword string
mrevstr = '$Rev: 1251 $';
mrevstr = mrevstr(7:(end-2));
LogIt(sprintf('%s REVISION=dirmsp:%s\n', marg.logstr, mrevstr), 1, marg.LOGTHRESH);
marg.REVSTR = mrevstr;

% Environment must describe DVEC, since observations may not cover all
% categories and marginal prior may not be in use

% SEEDS
marg = setAndLogSeeds(marg, 1);

% ----------------------------------------------------------------------
% BOILERPLATE

% Set up a summing matrix (applied to cell vectors below to get margins)
marger = getmarger(marg.DVEC);
ncell = prod(marg.DVEC);

% ----------------------------------------------------------------------
% "TRUTH" dataset for comparing LAMBDA-MAX
% Define-then-clear: keeps M-Lint quiet, but leaves mobs undefined so the
% exist() check below really tests whether the evaluated file created it.
mobs = []; clear mobs;
try
    eval(strcat(marg.TRUEDATAFILE, ';'));
catch ME
    error('Could not source %s : %s', marg.TRUEDATAFILE, ME.message);
end
if ~exist('mobs', 'var')
    error('%s did not contain mobs variable', marg.TRUEDATAFILE);
end
trueobs = mobs; clear mobs;
truevec = vec2idx(trueobs, marg.DVEC);

% What it really is: cell counts and empirical cell proportions
mres.truecount = histc(truevec, 1:ncell);
mres.obspi = mres.truecount / sum(mres.truecount);

% Add small fractional counts so we get probability everywhere:
% empty cells get a tenth of the smallest non-zero proportion, then renormalize.
mres.truepi = mres.obspi;
minpi = min(mres.truepi(mres.truepi>0));
mres.truepi(mres.truepi==0) = minpi/10;
mres.truepi = mres.truepi / sum(mres.truepi);

% ----------------------------------------------------------------------
% ALWAYS LOAD CASE DATA from somewhere, e.g. ACS_2010_PUMS
%
% If we've indicated that we want to draw a sample from the truth
% to use as our "trainobs", do so

if marg.CASEFROMTRUE == 1
    LogIt(sprintf('Generating cases from true set, NTRAIN=%d\n', marg.NTRAIN), 1, 1);
    % sample cell indices with replacement, weighted by the smoothed truth
    trainobs = randsample(1:ncell, marg.NTRAIN, true, mres.truepi);
    stackobs = histc(trainobs', 1:ncell);
else
    if isempty(marg.CASEDATAFILE)
        error('No CASEDATAFILE specified');
    end

    % eval the filename and look for "mobs"
    mobs = []; clear mobs;  % stupid MATLAB placation
    evstr = strcat(marg.CASEDATAFILE,';');
    try
        eval(evstr);
    catch ME
        error('Could not source %s : %s', marg.CASEDATAFILE, ME.message);
    end

    if ~exist('mobs', 'var')
        error('%s did not contain mobs variable', marg.CASEDATAFILE);
    end

    trainobs = mobs; clear mobs;
    marg.NTRAIN = size(trainobs,1);
    LogIt(sprintf('Resetting marg.NTRAIN = %d\n', marg.NTRAIN), 1, 1);

    % Convert case-level data vectors into indices into the monster vector
    monstidx = vec2idx(trainobs, marg.DVEC);

    % Our "observation" is the histogram of the cell indices
    stackobs = histc(monstidx, 1:ncell);
end

% ------------------------------------------------------------------------
% ALWAYS LOAD MARGINAL DATA from somewhere
% e.g., ACS_2010_MARG
% ------------------------------------------------------------------------
if  marg.MARGFROMTRUE == 1
    % one column per variable, padded to max(DVEC) rows
    truedist = reshape(marger * mres.truepi, max(marg.DVEC), numel(marg.DVEC));
    truemarg = truedist;
else
    if isempty(marg.MARGDATAFILE)
        error('MARGDATAFILE not specified for GENMARGFROM');
    end

    % placate M-Lint. This MUST come before the eval: clearing afterwards
    % would discard the truedist the file just defined, and the exist()
    % check below would always fail.
    truedist = []; clear truedist;

    evstr = strcat(marg.MARGDATAFILE, ';');
    try
        eval(evstr);
    catch ME
        error('Could not load MARGDATAFILE=[%s] : %s', marg.MARGDATAFILE, ME.message);
    end

    % File must define the 'truedist' variable
    if ~exist('truedist', 'var')
        error('MARGDATAFILE=[%s] did not define truedist', marg.MARGDATAFILE);
    end

    % Normalize each column of truedist to sum to one
    truemarg = truedist ./ kron(sum(truedist,1), ones(size(truedist,1),1));
end

mres.mobs = stackobs;

% ------------------------------------------------------------------------
% THE PRIORS
% ------------------------------------------------------------------------

% Prior 1: Dirichlet prior made informative w/marginal info : CONJUGATE
% The last column of the iterpropfitPdim output is normalized to proportions.
fitmat = iterpropfitPdim(truedist, marg.DVEC);
fitprop = fitmat(:,end) / sum(fitmat(:,end));
mres.post1 = stackobs + fitprop * marg.P1ALPHA;

% Prior 2: "Uninformative" Dirichlet prior, CONJUGATE
% the "ones" mean that the prior is extremely informative about the margins
% and probably clobber any information in the data
% if P2ALPHA == 1, this is the classic "uninformative" Dirichlet prior
% O/W you get a symmetric Dirichlet of differing information...
mres.post2 = stackobs + ones(size(stackobs)) * marg.P2ALPHA;

% Prior 3: "Uninformative" Dirichlet prior with margins replaced,
% fitted by the M-H routine selected with marg.P3METHOD
if strcmpi(marg.P3METHOD, 'SINGLEOPT')
    post3info = dirmspMHoptR00(marg, stackobs, truedist);
elseif strcmpi(marg.P3METHOD, 'MULTIOPT')
    post3info = dirmspMHoptR01(marg, stackobs, truedist, stackobs + fitprop * marg.P3MHALPHA);
elseif strcmpi(marg.P3METHOD, 'SEQOPT')
    post3info = dirmspMHseqopt(marg, stackobs, truedist);
    mres.post3sample = post3info.effsamp;
elseif strcmpi(marg.P3METHOD, 'SEQBASE')
    % final argument is used only if marg.P3MHSOURCE == 1
    post3info = dirmspMHseqbase(marg, stackobs, truedist, fitprop * marg.P3MHALPHA);
    mres.post3sample = post3info.effsamp;
elseif strcmpi(marg.P3METHOD, 'GAMMAMH')
    post3info = dirmspMHblkgam(marg, stackobs, truedist);
    mres.post3sample = post3info.effsamp;
elseif strcmpi(marg.P3METHOD, 'SUBDIR')
    post3info = dirmspMHsubdir(marg, stackobs, truedist);
    mres.post3sample = post3info.effsamp;
    mres.post3thetasample = post3info.thetasamp;
else
    error('Unknown P3METHOD [%s]', marg.P3METHOD);
end
mres.post3 = post3info.postmean;

% NEW for comparison: symmetric Dirichlet with weight P3BASEALPHA
% (shaped like stackobs, as for prior 2)
mres.post4 = stackobs + ones(size(stackobs)) * marg.P3BASEALPHA;

% Save the training set in case we need to do some kind of computation
mres.trainobs = stackobs;

% MARGINAL dists (compared later, e.g. with K-L divergence).
% Each is reshaped like truedist and divided by the sum of its first column.
mres.truemarg = truemarg;
mres.marg1 = reshape(marger * mres.post1, size(truedist));
mres.marg1 = mres.marg1 / sum(mres.marg1(:,1));
mres.marg2 = reshape(marger * mres.post2, size(truedist));
mres.marg2 = mres.marg2 / sum(mres.marg2(:,1));
mres.marg3 = reshape(marger * mres.post3, size(truedist));
mres.marg3 = mres.marg3 / sum(mres.marg3(:,1));
mres.marg4 = reshape(marger * mres.post4, size(truedist));
mres.marg4 = mres.marg4 / sum(mres.marg4(:,1));

% Normalized joint vectors, for comparison with mres.truepi
mres.join1 = mres.post1 / sum(mres.post1);
mres.join2 = mres.post2 / sum(mres.post2);
mres.join3 = mres.post3 / sum(mres.post3);
mres.join4 = mres.post4 / sum(mres.post4);

% quantity to use is mres.mldf1.ldfstack
% Compare each to mres.mldftrue.ldfstack using MSE: do it in post
mres.mldftrue = getmargtables(marg.DVEC, mres.truepi);
mres.mldf1 = getmargtables(marg.DVEC, mres.post1);
mres.mldf2 = getmargtables(marg.DVEC, mres.post2);
mres.mldf3 = getmargtables(marg.DVEC, mres.post3);
mres.mldf4 = getmargtables(marg.DVEC, mres.post4);

% ACCEPTANCE RATE
mres.p3accrate = post3info.accrate;

% An empty/unknown P3METHOD has already raised an error in the dispatch above,
% so the method name is always available here.
mres.MHMETHOD = marg.P3METHOD;

% ------------------------------------------------------------------------
% Save the result to a file named from the OUTPNAME environment variable
% ------------------------------------------------------------------------
fname = getenv('OUTPNAME');
if isempty(fname)
    error('OUTPNAME environment variable is not set');
end

% Get the base: drop a trailing '.txt'-style extension (last 4 characters)
if numel(fname) >= 3 && strcmpi(fname((end-2):end), 'txt')
    bname = fname(1:(end-4));
else
    bname = fname;
end

% Append the sample index
bname = strcat(bname, sprintf('_%5.5d',marg.sampidx));

setenv('OUTBNAME', bname);

LogIt(sprintf('%s Saving results to %s.mat\n', marg.logstr, bname), 1, marg.LOGTHRESH);
save(bname, 'mres', 'marg');

% Set after saving: present in the returned struct only
mres.logstr = 'FINISHED';

end
