% This function is part of the NMSM Pipeline, see file for full license.
%
% 
%
% (struct, Array of string, struct, struct, double) -> (Array of double)
% Calculate cost for a Ground Contact Personalization task. 
function cost = calcGroundContactPersonalizationTaskCost( ...
    values, fieldNameOrder, inputs, params, task)
% Builds the cost row vector for one Ground Contact Personalization task by
% concatenating the cost of every foot listed in inputs.surfaces.
%
% values         - flat array of design variable values
% fieldNameOrder - names giving the order in which values is packed
% inputs         - struct with static values and per-foot surfaces
% params         - struct whose tasks{task} holds designVariables and
%                  costTerms
% task           - index of the task to evaluate

% OpenSim models are non-serializable objects, so they cannot normally be
% passed to a parallel pool. The persistent struct below keeps one model
% per foot, opened the first time it is needed on each parallel worker and
% reused on every later call.
persistent models;
numberOfFeet = length(inputs.surfaces);
for footIndex = 1:numberOfFeet
    modelField = "model_" + footIndex;
    if ~isfield(models, modelField)
        models.(modelField) = Model(inputs.surfaces{footIndex}.model);
    end
end

valuesStruct = unpackValues(values, inputs, fieldNameOrder);

% Any design variable that is switched off for this task is absent from
% values, so its static value is copied from inputs to keep valuesStruct
% complete for the cost calculations. Entries 1 through 5 of
% designVariables correspond, in order, to the shared fields listed here.
isDesignVariable = params.tasks{task}.designVariables;
sharedFields = ["springConstants", "dampingFactor", ...
    "dynamicFrictionCoefficient", "viscousFrictionCoefficient", ...
    "restingSpringLength"];
for index = 1:numel(sharedFields)
    if ~isDesignVariable(index)
        valuesStruct.(sharedFields(index)) = inputs.(sharedFields(index));
    end
end
% Entry 6 controls the per-foot kinematic B-spline coefficients.
if ~isDesignVariable(6)
    for footIndex = 1:numberOfFeet
        valuesStruct.("bSplineCoefficients" + footIndex) = ...
            inputs.surfaces{footIndex}.bSplineCoefficients;
    end
end

cost = [];
for footIndex = 1:numberOfFeet
    surface = inputs.surfaces{footIndex};
    % Coefficients are stored flat and reshaped to 7 columns
    % (presumably one per joint coordinate - confirm against
    % calcGCPJointKinematics).
    coefficients = reshape( ...
        valuesStruct.("bSplineCoefficients" + footIndex), [], 7);
    [jointPositions, jointVelocities] = calcGCPJointKinematics( ...
        surface.experimentalJointPositions, ...
        surface.jointKinematicsBSplines, coefficients);
    modeledValues = calcGCPModeledValues(inputs, valuesStruct, ...
        jointPositions, jointVelocities, params, task, footIndex, models);
    % The kinematics are attached so kinematic cost terms can use them.
    modeledValues.jointPositions = jointPositions;
    modeledValues.jointVelocities = jointVelocities;
    footCost = calcCost(inputs, params, modeledValues, valuesStruct, ...
        task, footIndex);
    cost = [cost footCost];
end
end
% Reformats the flat values array of design variables to a simpler struct
% with one field per entry of fieldNameOrder.
%
% values         - flat array of design variable values, packed in the
%                  order given by fieldNameOrder
% inputs         - struct whose fields (or surfaces{foot} fields) define
%                  how many elements each design variable occupies
% fieldNameOrder - string array of design variable names
%
% Spring constants are multiplied by 1000 when unpacked. Field names
% containing "bSplineCoefficients" must end in the foot index, which may
% have more than one digit.
function valuesStruct = unpackValues(values, inputs, fieldNameOrder)
valuesStruct = struct();
start = 1;
for i = 1:length(fieldNameOrder)
    fieldName = fieldNameOrder(i);
    if contains(fieldName, "bSplineCoefficients")
        % Kinematics are specific to each foot, so the element count comes
        % from that foot's surface. The whole trailing run of digits is
        % parsed so foot indices above 9 resolve correctly.
        footDigits = regexp(convertStringsToChars(fieldName), ...
            '\d+$', 'match', 'once');
        foot = str2double(footDigits);
        count = numel(inputs.surfaces{foot}.bSplineCoefficients);
        valuesStruct.(fieldName) = values(start:start + count - 1);
    else
        % Other design variables are shared between feet and sized by the
        % matching field of inputs.
        count = numel(inputs.(fieldName));
        valuesStruct.(fieldName) = values(start:start + count - 1);
        if fieldName == "springConstants"
            valuesStruct.(fieldName) = 1000 * valuesStruct.(fieldName);
        end
    end
    start = start + count;
end
end
% Calculates the overall cost for one foot using allowable errors and all
% included cost terms.
%
% inputs        - struct; inputs.surfaces{foot} supplies experimental data
% params        - struct; params.tasks{task}.costTerms is a cell array of
%                 cost terms with type, isEnabled and maxAllowableError
% modeledValues - struct of modeled quantities for this foot
% valuesStruct  - struct of design variable values
% task          - index of the task being evaluated
% foot          - index of the foot being evaluated
%
% Returns a row vector made of every enabled term's errors, each scaled by
% sqrt(1 / number of errors) / maxAllowableError. Returns [] when no term
% is enabled. Throws an MException for an unrecognized cost term type.
function cost = calcCost(inputs, params, modeledValues, valuesStruct, ...
    task, foot)
cost = [];
costTerms = params.tasks{task}.costTerms;
% Prepare reused cost calculations. The list starts as an empty string
% array so the membership checks below compare strings with strings even
% when no cost term is enabled.
includedCostTypes = strings(1, 0);
for term = 1:length(costTerms)
    if costTerms{term}.isEnabled
        includedCostTypes = [includedCostTypes convertCharsToStrings( ...
            costTerms{term}.type)];
    end
end
% Each error calculation is shared by a value term and a slope term, so it
% is run once and only if at least one of its terms is enabled.
if any(ismember(includedCostTypes, ...
        ["marker_position" "marker_slope"]))
    [footMarkerPositionError, footMarkerSlopeError] = ...
        calcFootMarkerPositionAndSlopeError(inputs.surfaces{foot}, ...
        modeledValues);
end
if any(ismember(includedCostTypes, ["vertical_grf" ...
        "vertical_grf_slope" "horizontal_grf" "horizontal_grf_slope"]))
    % Horizontal forces default to zero when they were not modeled.
    if ~isfield(modeledValues, 'anteriorGrf')
        modeledValues.anteriorGrf = zeros(size(modeledValues.verticalGrf));
        modeledValues.lateralGrf = zeros(size(modeledValues.verticalGrf));
    end
    [groundReactionForceValueErrors, groundReactionForceSlopeErrors] = ...
        calcGroundReactionForceAndSlopeError(inputs.surfaces{foot}, ...
        modeledValues);
end
if any(ismember(includedCostTypes, ...
        ["ground_reaction_moment" "ground_reaction_moment_slope"]))
    [groundReactionMomentErrors, groundReactionMomentSlopeErrors] = ...
        calcGroundReactionMomentAndSlopeError(inputs.surfaces{foot}, ...
        modeledValues);
end
% Append all cost terms
for term = 1:length(costTerms)
    costTerm = costTerms{term};
    if costTerm.isEnabled
        switch costTerm.type
            case "marker_position"
                rawCost = footMarkerPositionError;
            case "marker_slope"
                rawCost = footMarkerSlopeError;
            case "rotation"
                % Rows 1 to 4 of the joint positions, compared in degrees.
                rawCost = reshape(rad2deg(modeledValues ...
                    .jointPositions(1:4, :)) - rad2deg(inputs ...
                    .surfaces{foot}.experimentalJointPositions(1:4, ...
                    :)), 1, []);
            case "translation"
                % Rows 5 to 7 of the joint positions, compared directly.
                rawCost = reshape(modeledValues.jointPositions(5:7, :) ...
                    - inputs.surfaces{foot}.experimentalJointPositions( ...
                    5:7, :), 1, []);
            case "vertical_grf"
                rawCost = groundReactionForceValueErrors(2, :);
            case "vertical_grf_slope"
                rawCost = groundReactionForceSlopeErrors(2, :);
            case "horizontal_grf"
                rawCost = reshape(groundReactionForceValueErrors([1 3], ...
                    :), 1, []);
            case "horizontal_grf_slope"
                rawCost = reshape(groundReactionForceSlopeErrors([1 3], ...
                    :), 1, []);
            case "ground_reaction_moment"
                rawCost = reshape(groundReactionMomentErrors, 1, []);
            case "ground_reaction_moment_slope"
                rawCost = reshape(groundReactionMomentSlopeErrors, 1, []);
            case "spring_constant_mean"
                rawCost = calcSpringConstantsErrorFromMean( ...
                    valuesStruct.springConstants);
            case "neighbor_spring_constant"
                % The normalized error is raised to the fourth power and
                % multiplied back by maxAllowableError, so the common
                % scaling below leaves the fourth-power term.
                rawCost = (calcSpringConstantsErrorFromNeighbors( ...
                    valuesStruct.springConstants, ...
                    modeledValues.gaussianWeights) / ...
                    costTerm.maxAllowableError) .^ 4 * ...
                    costTerm.maxAllowableError;
            otherwise
                % char() keeps the message a single character vector even
                % when the type is stored as a string scalar.
                throw(MException('', ['Cost term type ' ...
                    char(costTerm.type) ...
                    ' does not exist for this tool.']))
        end
        cost = [cost sqrt(1 / length(rawCost)) * ...
            1 / costTerm.maxAllowableError * rawCost];
    end
end
end