Skip to main content

getTreatmentOptimizationInputs.m


% This function is part of the NMSM Pipeline, see file for full license.
%
% This function parses the settings tree resulting from xml2struct from the
% settings XML file common to all treatment optimization modules (tracking,
% verification, and design optimization).
%
% (struct) -> (struct)
% Returns the input values for all treatment optimization modules.


function inputs = getTreatmentOptimizationInputs(tree)
% Builds the input struct shared by all treatment optimization modules.
%
% TREE is the struct produced by xml2struct from the settings XML file.
% INPUTS collects the parsed settings. Which controller-specific fields
% are added depends on the <type_of_controller> text: "synergy_driven" and
% "torque_driven" each add their own set, and any other value adds none.

% Results directory; an empty value falls back to the current folder.
inputs.resultsDirectory = getTextFromField( ...
    getFieldByName(tree, 'results_directory'));
if isempty(inputs.resultsDirectory)
    inputs.resultsDirectory = pwd;
end

inputs.controllerType = getTextFromField( ...
    getFieldByNameOrError(tree, 'type_of_controller'));
inputs.model = parseModel(tree);
osimxFileName = getTextFromField(getFieldByName(tree, 'osimx_file'));
inputs.osimx = parseOsimxFile(osimxFileName);

% Controller-specific settings.
if strcmp(inputs.controllerType, 'synergy_driven')
    inputs = parseSynergyDrivenInputs(tree, inputs);
elseif strcmp(inputs.controllerType, 'torque_driven')
    inputs = parseTorqueDrivenInputs(tree, inputs);
end

% NOTE(review): this fallback is numeric 0, while the other
% parseElementTextByNameOrAlternate calls pass strings; confirm that
% getBooleanLogic accepts both forms.
optimizeSynergyVectorsText = parseElementTextByNameOrAlternate(tree, ...
    "optimize_synergy_vectors", 0);
inputs.optimizeSynergyVectors = getBooleanLogic(optimizeSynergyVectorsText);

% The data-directory parser receives the struct built so far and returns
% an updated copy.
inputs = parseTreatmentOptimizationDataDirectory(tree, inputs);
inputs.initialGuess = getGpopsInitialGuess(tree);

costTermSet = getFieldByNameOrError(tree, 'RCNLCostTermSet');
inputs.costTerms = parseRcnlCostTermSet(costTermSet.RCNLCostTerm);
inputs.path = getPathConstraintTerms(tree);
inputs.terminal = getTerminalConstraintTerms(tree);
inputs.contactSurfaces = parseContactSurfaceInputs(inputs);
end

function inputs = parseSynergyDrivenInputs(tree, inputs)
% Adds the synergy-driven controller settings to INPUTS: synergy groups
% and their counts, surrogate model coordinates and the muscles found from
% them, epsilon / v_max_factor, and the surrogate model coefficients
% loaded from the file named in <surrogate_model_coefficients>.
inputs.synergyGroups = getSynergyGroups(tree, Model(inputs.model));
inputs.numSynergies = getNumSynergies(inputs.synergyGroups);
inputs.numSynergyWeights = getNumSynergyWeights(inputs.synergyGroups);
inputs.surrogateModelCoordinateNames = parseSpaceSeparatedList(tree, ...
    "coordinate_list");
inputs.muscleNames = getMusclesFromCoordinates(inputs.model, ...
    inputs.surrogateModelCoordinateNames);
inputs.numMuscles = length(inputs.muscleNames);
% The string arguments are presumably fallbacks used when the element is
% absent; confirm against parseElementTextByNameOrAlternate.
inputs.epsilon = str2double(parseElementTextByNameOrAlternate(tree, ...
    "epsilon", "1e-4"));
inputs.vMaxFactor = str2double(parseElementTextByNameOrAlternate(tree, ...
    "v_max_factor", "10"));
coefficientsFileName = getTextFromField(getFieldByName(tree, ...
    'surrogate_model_coefficients'));
surrogateModelCoefficients = load(coefficientsFileName);
inputs.coefficients = surrogateModelCoefficients.coefficients;
inputs = getModelOrOsimxInputs(inputs);
end

function inputs = parseTorqueDrivenInputs(tree, inputs)
% Adds the torque-driven controller settings to INPUTS: the names listed
% in <coordinate_list> and how many there are.
inputs.controlTorqueNames = parseSpaceSeparatedList(tree, ...
    "coordinate_list");
inputs.numTorqueControls = length(inputs.controlTorqueNames);
end

function contactSurfaces = parseContactSurfaceInputs(inputs)
% Prepares ground contact surfaces when the osimx data yields a struct or
% cell of "contactSurface" entries AND INPUTS has a grfFileName field;
% otherwise returns an empty cell array.
osimxSurfaces = getFieldByName(inputs.osimx, "contactSurface");
hasSurfaces = isstruct(osimxSurfaces) || iscell(osimxSurfaces);
if ~(hasSurfaces && isfield(inputs, "grfFileName"))
    contactSurfaces = {};
    return
end
contactSurfaces = prepareGroundContactSurfaces(inputs.model, ...
    osimxSurfaces, inputs.grfFileName);
end