% This function is part of the NMSM Pipeline, see file for full license.
%
% This function parses the settings tree resulting from xml2struct from the
% settings XML file common to all treatment optimization modules (tracking,
% verification, and design optimization).
%
% (struct) -> (struct)
% returns the input values for all treatment optimization modules
function inputs = getTreatmentOptimizationInputs(tree)
% Build the input struct shared by the treatment optimization modules.
%
% The function reads the following, in this order:
%   - the results directory (falls back to pwd when the text is empty)
%   - the controller type
%   - the model
%   - the osimx file
%   - the controller-specific settings, for 'synergy_driven' and
%     'torque_driven' only; other controller types add no extra fields here
%   - the optimize_synergy_vectors flag
%   - the data directory contents
%   - the GPOPS initial guess
%   - the RCNL cost terms
%   - the path and terminal constraint terms
%   - the ground contact surfaces
%
% (struct) -> (struct)

resultsDirectory = getTextFromField( ...
    getFieldByName(tree, 'results_directory'));
if isempty(resultsDirectory)
    resultsDirectory = pwd;
end
inputs.resultsDirectory = resultsDirectory;

controllerType = getTextFromField( ...
    getFieldByNameOrError(tree, 'type_of_controller'));
inputs.controllerType = controllerType;
inputs.model = parseModel(tree);
osimxFileName = getTextFromField(getFieldByName(tree, 'osimx_file'));
inputs.osimx = parseOsimxFile(osimxFileName);

% strcmp (rather than switch) keeps non-char controller values from erroring
if strcmp(controllerType, 'synergy_driven')
    inputs = parseSynergyDrivenInputs(tree, inputs);
elseif strcmp(controllerType, 'torque_driven')
    inputs = parseTorqueDrivenInputs(tree, inputs);
end

optimizeSynergyVectorsText = parseElementTextByNameOrAlternate(tree, ...
    "optimize_synergy_vectors", 0);
inputs.optimizeSynergyVectors = getBooleanLogic(optimizeSynergyVectorsText);
inputs = parseTreatmentOptimizationDataDirectory(tree, inputs);
inputs.initialGuess = getGpopsInitialGuess(tree);
% inputs.experimentalTime = inputs.experimentalTime / ...
%     inputs.experimentalTime(end);
costTermSet = getFieldByNameOrError(tree, 'RCNLCostTermSet');
inputs.costTerms = parseRcnlCostTermSet(costTermSet.RCNLCostTerm);
inputs.path = getPathConstraintTerms(tree);
inputs.terminal = getTerminalConstraintTerms(tree);
inputs.contactSurfaces = getContactSurfacesIfAvailable(inputs);
end

function inputs = parseSynergyDrivenInputs(tree, inputs)
% Add the synergy-driven controller settings to inputs:
%   - synergy groups and their synergy/weight counts
%   - the surrogate model coordinates and the muscles that span them
%   - epsilon and vMaxFactor (text defaults "1e-4" and "10")
%   - the surrogate model coefficients loaded from the named file
% It then passes inputs through getModelOrOsimxInputs.
synergyGroups = getSynergyGroups(tree, Model(inputs.model));
inputs.synergyGroups = synergyGroups;
inputs.numSynergies = getNumSynergies(synergyGroups);
inputs.numSynergyWeights = getNumSynergyWeights(synergyGroups);

coordinateNames = parseSpaceSeparatedList(tree, "coordinate_list");
inputs.surrogateModelCoordinateNames = coordinateNames;
muscleNames = getMusclesFromCoordinates(inputs.model, coordinateNames);
inputs.muscleNames = muscleNames;
inputs.numMuscles = length(muscleNames);

epsilonText = parseElementTextByNameOrAlternate(tree, "epsilon", "1e-4");
inputs.epsilon = str2double(epsilonText);
vMaxFactorText = parseElementTextByNameOrAlternate(tree, ...
    "v_max_factor", "10");
inputs.vMaxFactor = str2double(vMaxFactorText);

coefficientsFile = getTextFromField(getFieldByName(tree, ...
    'surrogate_model_coefficients'));
surrogateModel = load(coefficientsFile);
inputs.coefficients = surrogateModel.coefficients;
inputs = getModelOrOsimxInputs(inputs);
end

function inputs = parseTorqueDrivenInputs(tree, inputs)
% Add the torque-driven controller settings: the controlled coordinate
% names from coordinate_list and how many there are.
torqueNames = parseSpaceSeparatedList(tree, "coordinate_list");
inputs.controlTorqueNames = torqueNames;
inputs.numTorqueControls = length(torqueNames);
end

function contactSurfaces = getContactSurfacesIfAvailable(inputs)
% Prepare the ground contact surfaces when both conditions hold:
%   - the osimx data yields a "contactSurface" entry that is a struct or
%     a cell array;
%   - inputs carries a grfFileName field.
% Otherwise return an empty cell array.
surfaces = getFieldByName(inputs.osimx, "contactSurface");
hasSurfaces = isstruct(surfaces) || iscell(surfaces);
if ~(hasSurfaces && isfield(inputs, "grfFileName"))
    contactSurfaces = {};
    return
end
contactSurfaces = prepareGroundContactSurfaces(inputs.model, surfaces, ...
    inputs.grfFileName);
end