% -----------------------------------------------------------------------
% getmargtables
% AUTHOR : DAVID KESSLER, dkessler@live.unc.edu
%
% Figures out the two-at-a-time marginal sums for a P-way table.
%
% dvec: vector giving the number of levels in each dimension of the table
% pvec: the joint probability vector of interest, with prod(dvec) elements
%
% This can do things in situ or provide the necessary matrices such that
% ldfvec = mres.lsummer * log(mres.msummer * pvec)
% -----------------------------------------------------------------------

% Condition on the values in P

function mres = getmargtables(dvec, pvec)
% GETMARGTABLES  Two-at-a-time marginal tables and their local log-odds.
%
%   mres = getmargtables(dvec, pvec)
%
% Inputs
%   dvec : vector with the size of each dimension of the table.  Vectors
%          with fewer than 3 entries are left-padded with ones so that at
%          least three dimensions are always present.
%   pvec : joint probability vector with prod(dvec) elements (any
%          orientation; it is used as a column).
%
% Output (struct mres)
%   msummer  : nxcell-by-prod(dvec) 0/1 matrix; msummer * pvec stacks the
%              cells of every pairwise marginal table.
%   lsummer  : nlocdep-by-nxcell matrix with +1/-1 entries such that
%              lsummer * log(msummer * pvec) gives the stacked local
%              log-odds ratios (same values as ldfstack).
%   mskey    : one row per msummer row: [pos1 pos2 val1 val2], the two
%              varying positions and their values for that cell.
%   sumidx   : per pair, first and last row of msummer belonging to it.
%   xtpos    : per pair (column), the varying positions, order [2nd; 1st].
%   xtdims   : per pair (column), the crosstab size [rows; cols].
%   xtabs    : per pair, the crosstab of marginal probabilities.
%   xtrow, xtrpos, xtcol, xtcpos :
%              per pair, matrices shaped like the crosstab holding the
%              row value, row position, column value and column position.
%   ldf      : per pair, row vector of local log-odds ratios.
%   ldfstack : all ldf vectors concatenated.
%   ldfjays  : per log-odds entry, [row position, column position].
%   ldfcees  : per log-odds entry, [ridx cidx] of the 2x2 block's corner.
%   ldffidx  : per log-odds entry, index of the pair it came from.
%   ldfinfo  : allocated as zeros(4, npairs); never filled in here.
%   mvary    : per pair (row), the two varying positions.
%   nsets    : number of pairs.
%
% Relies on idx2vec / vec2idx, which are defined outside this file; they
% are assumed to convert between linear indices and subscript vectors for
% the given dimension vector (TODO confirm ordering convention).

% Work with dvec as a row and pvec as a column regardless of how the
% caller oriented them.
dvec = dvec(:).';
pvec = pvec(:);

% Pad short dimension vectors with leading singleton dimensions.
% (The original test was numel(dvec < 3), which counts the elements of the
% logical array and is therefore true for every non-empty dvec.)
if numel(dvec) < 3
    dvec = [ones(1, 3 - numel(dvec))  dvec];
end

Sp = prod(dvec);

% The number of things to compute depends on the margins.

% Find binary numbers with 2 bits set - these are the ones we want
bindvec = ones(1, numel(dvec))*2;
seltab = idx2vec(1:prod(bindvec), bindvec) - 1;
maskb = seltab(sum(seltab,2)==2, :);

% Locations that vary: after sorting, the two non-zero position numbers
% end up in the last two columns.
mblk = sort((maskb==1) .* repmat(1:numel(dvec), size(maskb,1), 1), 2);
mvary = mblk(:, (numel(dvec)- 2 + 1):end);

% Locations that are marginalized over: the two zeros (the varying
% positions) sort first, so the marginalized positions start at column 3.
mblk = sort((maskb==0) .* repmat(1:numel(dvec), size(maskb,1), 1), 2);
mmarg = mblk(:, 3:end);

npairs = size(mvary,1);

% Qualities of the varying indices will tell us how many slots we need
dvary = dvec(mvary);
% (r-1)*(c-1) local log-odds per r x c table, written as rc - 1 - (r-1) - (c-1)
nlocdep = sum(prod(dvary,2) - 1 - sum(dvary-1, 2));
nxcell = sum(prod(dvary,2));

% Matrix telling how to add up joint distribution to get different r x c
% tables
mres.msummer = zeros(nxcell, Sp);

% Matrix telling how to add up the LOGS of the result of applying
% mres.msummer to the pvec: reshape-able into something with
% 4 rows or 4 cols, whatever.
mres.lsummer = zeros(nlocdep, nxcell);

mres.mskey = zeros(nxcell, 4);
mres.sumidx = zeros(npairs,2);

sidx = 0;

% The full index set
mvecs = idx2vec(1:Sp, dvec);

% Information for the caller
mres.xtpos = zeros(2, npairs);
mres.xtdims = zeros(2, npairs);
mres.xtabs = cell(1, npairs);
mres.ldfinfo = zeros(4, npairs);
mres.ldf = cell(1, npairs);
mres.xtrow = cell(1, npairs);
mres.xtrpos = cell(1, npairs);
mres.xtcol = cell(1, npairs);
mres.xtcpos = cell(1, npairs);

% Tracking entries
mres.ldfstack = zeros(1, nlocdep);

% Which cols are in play
mres.ldfjays = zeros(nlocdep, 2);

% Which offsets within those cols
mres.ldfcees = zeros(nlocdep, 2);

% Which table the ldf is from
mres.ldffidx = zeros(nlocdep, 1);

% Not as clear how to figure out how to compute all contributing indices

% slidx: first row of lsummer / first slot of ldfstack for the current pair
slidx = 1;
for fidx = 1 : size(mmarg,1)

    % Figure out what our values of interest are
    lcldvec = dvec(mvary(fidx,:));
    varvec = idx2vec(1:prod(lcldvec), lcldvec);
    sgpidx = sidx + 1;

    % Linear index (within the pair's own table) of every joint cell;
    % loop-invariant, so computed once per pair.
    pairidx = vec2idx(mvecs(:,mvary(fidx,:)), lcldvec);

    % Compute the summing entries for each
    for vidx = 1 : size(varvec,1)
        sidx = sidx + 1;
        mres.msummer(sidx,:) = (pairidx == vidx) * 1;
        mres.mskey(sidx,:) = [mvary(fidx,:) varvec(vidx,:)];
    end

    fgpidx = sidx;
    mres.sumidx(fidx,:) = [sgpidx fgpidx];

    % Apply the summer to the probability matrix and reshape it
    rawmargs = mres.msummer(sgpidx:fgpidx,:) * pvec;

    % The crosstab is stored with the second varying position along the
    % rows and the first along the columns.
    mres.xtpos(:,fidx) = mvary(fidx,[2 1]);
    mres.xtdims(:,fidx) = lcldvec([2 1]);
    curxtab = reshape(rawmargs, lcldvec(2), lcldvec(1));
    mres.xtabs{fidx} = curxtab;

    mres.xtrow{fidx} = reshape(mres.mskey(sgpidx:fgpidx,4), lcldvec(2), lcldvec(1));
    mres.xtrpos{fidx} = reshape(mres.mskey(sgpidx:fgpidx,2), lcldvec(2), lcldvec(1));
    mres.xtcol{fidx} = reshape(mres.mskey(sgpidx:fgpidx,3), lcldvec(2), lcldvec(1));
    mres.xtcpos{fidx} = reshape(mres.mskey(sgpidx:fgpidx,1), lcldvec(2), lcldvec(1));

    curldf = zeros(1, prod(mres.xtdims(:,fidx)-1));
    lidx = 0;

    % Very likely a good way to do this without the loops
    % Given the current crosstabulation, compute the non-redundant LDF
    % mres.xtdims(1,fidx) = rows,  mres.xtdims(2,fidx) = cols

    for ridx = 1 : (mres.xtdims(1,fidx)-1)
        for cidx = 1: (mres.xtdims(2,fidx)-1)
            % 1 2 3 4 of the usual 2x2 table
            mfacts = [curxtab(ridx,cidx) curxtab(ridx,cidx+1) curxtab(ridx+1,cidx) curxtab(ridx+1,cidx+1)];
            lidx = lidx + 1;
            % Local log-odds ratio of adjacent rows/columns
            curldf(lidx) = log(mfacts(1)) + log(mfacts(4)) - log(mfacts(2)) - log(mfacts(3));
            % Rows of msummer (columns of lsummer) holding the same four
            % cells, listed in the same order as mfacts.
            lsidx = vec2idx([cidx ridx; cidx+1 ridx; cidx ridx+1; cidx+1 ridx+1], mres.xtdims([2 1],fidx)) + sgpidx - 1;
            mres.lsummer(slidx + lidx - 1, lsidx) = [1 -1 -1 1];
            mres.ldfjays(slidx + lidx - 1, :) = [mres.xtrpos{fidx}(ridx,cidx)  mres.xtcpos{fidx}(ridx,cidx)];
            mres.ldfcees(slidx + lidx - 1, :) = [ridx  cidx];
            mres.ldffidx(slidx+lidx-1) = fidx;
        end
    end

    % Vectorize approach
    % [mR,mC] = meshgrid(1 : (mres.xtdims(1,fidx)-1),  1: (mres.xtdims(2,fidx)-1));
    % vR = mR(:);
    % vC = mC(:);

    mres.ldf{fidx} = curldf;
    elidx = slidx + numel(curldf) - 1;
    mres.ldfstack(slidx:elidx) = curldf;
    slidx = elidx + 1;
end

mres.mvary = mvary;
mres.nsets = npairs;

end
