% This function is part of the NMSM Pipeline, see file for full license.
%
% This function returns the vector of SynX variables (including both SynX
%   and residual vector weights)
%
% inputs:
%   params - struct of SynX settings; fields read here:
%     numberOfSynergies - number of synergies - double
%     matrixFactorizationMethod - matrix factorization algorithm - 'PCA' or
%       'NMF'
%     taskNames - names of the tasks; its length is the number of tasks
%     synergyExtrapolationCategorization - variability of synergy vector
%       weights across trials for SynX reconstruction - 'subject', 'task'
%       or 'trial'
%     residualCategorization - variability of synergy vector weights across
%       trials for residual excitation reconstruction - 'subject', 'task'
%       or 'trial'
%   numberOfTrials - number of trials in total - double
%   numberOfMeasuredEmgChannels - number of measured EMG channels in total -
%     double
%   numberOfUnmeasuredEmgChannels - number of SynX EMG channels in total -
%     double
%
% returns a zero-filled row vector holding the SynX weights followed by the
% residual weights, plus the number of SynX weights and the number of
% residual weights
function [synergyWeights, numberOfExtrapolationWeights, ...
    numberOfResidualWeights] = getSynergyWeights(params, numberOfTrials, ...
    numberOfMeasuredEmgChannels, numberOfUnmeasuredEmgChannels)
% Build the zero-filled SynX design vector: every extrapolation (SynX)
% weight first, then every residual weight, as a single row vector. The
% shape of each weight group is decided by getSizeMatrix from the
% categorization and factorization settings in params.
extrapolationCategory = params.synergyExtrapolationCategorization;
factorization = params.matrixFactorizationMethod;
synergyCount = params.numberOfSynergies;
numberOfTasks = length(params.taskNames);

% Placeholder weights for the unmeasured (SynX) EMG channels
extrapolationBlock = getSizeMatrix(extrapolationCategory, factorization, ...
    synergyCount, numberOfUnmeasuredEmgChannels, numberOfTasks, ...
    numberOfTrials);
% Placeholder weights for the residual excitations of measured channels
residualBlock = getSizeMatrix(params.residualCategorization, ...
    factorization, synergyCount, numberOfMeasuredEmgChannels, ...
    numberOfTasks, numberOfTrials);

% Element counts let callers split the flat vector back into its groups
numberOfExtrapolationWeights = numel(extrapolationBlock);
numberOfResidualWeights = numel(residualBlock);
synergyWeights = [reshape(extrapolationBlock, 1, []), ...
    reshape(residualBlock, 1, [])];
end
function matrixPlaceholder = getSizeMatrix(categorizationMethod, ...
    factorizationMethod, numberOfSynergies, numberOfEmgChannels, ...
    numberOfTasks, numberOfTrials)
% Return a zero-filled weight placeholder whose third dimension depends on
% how the weights are allowed to vary (matched case-insensitively):
%   'subject' - one set of weights for all trials  -> size 1
%   'task'    - one set of weights per task        -> size numberOfTasks
%   'trial'   - one set of weights per trial       -> size numberOfTrials
% The first two dimensions are chosen by createSizeMatrix from the
% factorization method, synergy count and EMG channel count.
% Any other categorization raises an error. Previously the output was
% silently left unassigned, which only surfaced as a cryptic
% "output argument not assigned" failure in the caller.
if strcmpi(categorizationMethod, 'subject')
    sizeThirdDimension = 1;
elseif strcmpi(categorizationMethod, 'task')
    sizeThirdDimension = numberOfTasks;
elseif strcmpi(categorizationMethod, 'trial')
    sizeThirdDimension = numberOfTrials;
else
    error('getSynergyWeights:unknownCategorization', ...
        ['Unsupported weight categorization "%s"; expected ''subject'', ' ...
        '''task'' or ''trial''.'], categorizationMethod);
end
matrixPlaceholder = createSizeMatrix(factorizationMethod, ...
    numberOfSynergies, numberOfEmgChannels, sizeThirdDimension);
end
function matrixPlaceholder = createSizeMatrix( ...
    factorizationMethod, numberOfSynergies, numberOfEmgChannels, ...
    sizeThirdDimension)
% Return a zero-filled array of size
%   [numberOfRows, numberOfEmgChannels, sizeThirdDimension]
% where numberOfRows is numberOfSynergies + 1 for 'PCA' and
% numberOfSynergies for 'NMF' (matched case-insensitively).
% NOTE(review): the extra PCA row presumably holds a per-channel
% offset/mean term; confirm against the SynX reconstruction code.
% Any other factorization method raises an error instead of leaving the
% output unassigned.
if strcmpi(factorizationMethod, 'PCA')
    numberOfRows = numberOfSynergies + 1;
elseif strcmpi(factorizationMethod, 'NMF')
    numberOfRows = numberOfSynergies;
else
    error('getSynergyWeights:unknownFactorization', ...
        ['Unsupported matrix factorization method "%s"; expected ' ...
        '''PCA'' or ''NMF''.'], factorizationMethod);
end
matrixPlaceholder = zeros(numberOfRows, numberOfEmgChannels, ...
    sizeThirdDimension);
end