Skip to main content

parseJointModelPersonalizationSettingsTree.m


% This function is part of the NMSM Pipeline, see file for full license.
%
% This function parses the settings tree resulting from xml2struct of the
% Joint Model Personalization Settings XML file.
%
% (struct) -> (string, struct, struct)
% returns the input values for Joint Model Personalization


function [outputFile, inputs, params] = ...
    parseJointModelPersonalizationSettingsTree(settingsTree)
% Entry point: converts the parsed settings tree into the three values
% returned to the caller.
%
%   settingsTree - struct produced by xml2struct for the settings XML file
%
%   outputFile   - result of getOutputFile (output model file path)
%   inputs       - result of getInputs (model, tasks, desired error)
%   params       - result of getParams (optional optimizer settings)
%
% Each helper raises its own error on bad settings, so all parsing is
% finished here before any computation is started by the caller.
outputFile = getOutputFile(settingsTree);
inputs = getInputs(settingsTree);
params = getParams(settingsTree);
end

function outputFile = getOutputFile(tree)
% Builds the full path of the output model file.
%
%   tree       - settings tree; must contain <output_model_file> and may
%                contain <results_directory>
%
%   outputFile - fullfile(results_directory, output_model_file) when a
%                results directory is given, otherwise
%                fullfile(pwd, output_model_file)
%
% The results directory is created when it does not exist yet; an
% MException is thrown if it cannot be created.
outputFile = getFieldByNameOrError(tree, 'output_model_file').Text;

% results_directory is optional. The rest of this file detects a missing
% element by checking isstruct() on the getFieldByName result, so do the
% same here instead of reading .Text from it unconditionally.
resultsDir = '';
resultsDirField = getFieldByName(tree, 'results_directory');
if isstruct(resultsDirField)
    resultsDir = strtrim(resultsDirField.Text);
end

if isempty(resultsDir)
    outputFile = fullfile(pwd, outputFile);
    return
end

if ~exist(resultsDir, 'dir')
    try
        mkdir(resultsDir)
    catch
        % pass the text through '%s' so backslashes or percent signs in
        % the path are never interpreted as format/escape sequences
        throw(MException('', '%s', "Cannot find output directory " + ...
            resultsDir))
    end
end
outputFile = fullfile(resultsDir, outputFile);
end

function inputs = getInputs(tree)
% Collects the inputs struct from the settings tree.
%
%   inputs.model        - value returned by parseModel(tree)
%   inputs.tasks        - cell array of enabled tasks (see getTasks)
%   inputs.desiredError - numeric value of the required <allowable_error>
inputs.model = parseModel(tree);
% a Model object is built from inputs.model only so the task parser can
% query it; the object itself is not stored in the output
modelObject = Model(inputs.model);
inputs.tasks = getTasks(modelObject, tree);
allowableErrorField = getFieldByNameOrError(tree, 'allowable_error');
% NOTE(review): str2num evaluates its argument as MATLAB code, and this
% text comes straight from the settings file; str2double would be safer
% if only plain scalar values are expected here - confirm before changing
inputs.desiredError = str2num(allowableErrorField.Text);
end

function inputs = getTasks(model, tree)
% Parses every enabled <JMPTask> inside the required <JMPTaskList>.
%
%   model  - model object forwarded to getTask
%   tree   - settings tree containing JMPTaskList
%
%   inputs - 1xN cell array with one getTask result per task whose
%            <is_enabled> text is 'true' (case-insensitive), in the
%            order given by orderByIndex; {} when no task is enabled
%
% Raises whatever getFieldByNameOrError raises when JMPTaskList is absent.

% initialise the output so that an empty or fully disabled task list
% returns {} instead of failing with an unassigned output argument
inputs = {};
tasks = getFieldByNameOrError(tree, 'JMPTaskList');
jmpTasks = orderByIndex(tasks.JMPTask);
for i = 1:length(jmpTasks)
    % a single task may arrive as a bare struct rather than a cell
    % array; index according to the actual container type
    if iscell(jmpTasks)
        task = jmpTasks{i};
    else
        task = jmpTasks(i);
    end
    if strcmpi(task.is_enabled.Text, 'true')
        inputs{end + 1} = getTask(model, task);
    end
end
end

function output = getTask(model, tree)
% Parses a single <JMPTask> element into a task struct.
%
%   model  - model object used to look up initial parameter values
%   tree   - struct for one JMPTask element
%
%   output.markerFile             - text of <marker_file_name>
%   output.startTime/.finishTime  - only set when <time_range> is present
%   output.parameters             - joint parameters ({} without JMPJointSet)
%   output.scaling, .markers      - body parameters ([] without JMPBodySet)
%   output.initialValues          - see getInitialValues
%   output.lowerBounds/.upperBounds - only set when the bounds test below
%                                     passes
%
% Throws an MException when <time_range> holds fewer than two values.
output.markerFile = tree.marker_file_name.Text;

timeRange = getFieldByName(tree, 'time_range');
if isstruct(timeRange)
    % trim and split on any run of whitespace so padded values or values
    % separated by tabs/newlines parse correctly instead of becoming NaN
    timeValues = str2double(strsplit(strtrim(timeRange.Text)));
    if length(timeValues) < 2
        throw(MException('', '%s', "time_range must contain a start " + ...
            "and a finish time, received: " + timeRange.Text))
    end
    output.startTime = timeValues(1);
    output.finishTime = timeValues(2);
end

output.parameters = {};
if isstruct(getFieldByName(tree, "JMPJointSet"))
    output.parameters = getJointParameters(tree.JMPJointSet);
end

output.scaling = [];
output.markers = [];
if isstruct(getFieldByName(tree, "JMPBodySet"))
    [output.scaling, output.markers] = ...
        getBodyParameters(tree.JMPBodySet, model);
end

% when an element is missing, the variable keeps whatever non-struct
% value getFieldByName returned for it
translationBounds = getFieldByName(tree, 'translation_bounds');
if isstruct(translationBounds)
    translationBounds = str2double(translationBounds.Text);
end
orientationBounds = getFieldByName(tree, 'orientation_bounds');
if isstruct(orientationBounds)
    orientationBounds = str2double(orientationBounds.Text);
end

output.initialValues = getInitialValues(model, output.parameters, ...
    output.scaling, output.markers);

% NOTE(review): if only one of the two bounds is given, the other is
% passed to getBounds as the "missing" value from getFieldByName
% (presumably false, i.e. a zero-width bound) - confirm this is intended
if (translationBounds || orientationBounds)
    [output.lowerBounds, output.upperBounds] = getBounds( ...
        output.parameters, output.initialValues, ...
        translationBounds, orientationBounds, output.scaling, ...
        output.markers);
end
end

% Builds the list of joint frame parameters selected in a <JMPJointSet>.
%
%   jointSetTree - struct for the JMPJointSet element
%
%   inputs - 1xN cell array; each entry is
%            {jointName, isParentFrame, isTranslation, axisIndex}
%            with axisIndex in 0..2. For every joint the entries are
%            emitted in the order parent translation, parent orientation,
%            child translation, child orientation. {} when the set has no
%            JMPJoint children.
function inputs = getJointParameters(jointSetTree)
inputs = {};
if ~isfield(jointSetTree, "JMPJoint")
    return
end
jointTree = jointSetTree.JMPJoint;
for i = 1:length(jointTree)
    % a single joint may arrive as a bare struct rather than a cell
    % array; index according to the actual container type
    if iscell(jointTree)
        joint = jointTree{i};
    else
        joint = jointTree(i);
    end
    jointName = joint.Attributes.name;
    inputs = appendEnabledJointAxes(inputs, jointName, true, true, ...
        joint.parent_frame_transformation.translation.Text);
    inputs = appendEnabledJointAxes(inputs, jointName, true, false, ...
        joint.parent_frame_transformation.orientation.Text);
    inputs = appendEnabledJointAxes(inputs, jointName, false, true, ...
        joint.child_frame_transformation.translation.Text);
    inputs = appendEnabledJointAxes(inputs, jointName, false, false, ...
        joint.child_frame_transformation.orientation.Text);
end
end

% Appends one {jointName, isParent, isTranslation, axisIndex} entry to
% inputs for every axis whose flag in flagText is exactly 'true'.
function inputs = appendEnabledJointAxes(inputs, jointName, isParent, ...
    isTranslation, flagText)
% trim and split on any whitespace run so padded text still gives exactly
% three tokens
flags = strsplit(strtrim(flagText));
verifyLength(flags, 3);
for axisIndex = 0:2
    if strcmp(flags{axisIndex + 1}, 'true')
        inputs{end + 1} = {jointName, isParent, isTranslation, axisIndex};
    end
end
end

function [scaling, markers] = getBodyParameters(bodySetTree, model)
% Extracts the body-related parameters of a <JMPBodySet>.
%
%   bodySetTree - struct for the JMPBodySet element
%   model       - model object forwarded to getMarkers
%
%   scaling - string array of body names to scale (see getScalingBodies)
%   markers - cell array of [markerName, axis] pairs (see getMarkers)
%
% Both outputs are empty when the set contains no JMPBody children.
if ~isfield(bodySetTree, "JMPBody")
    scaling = string([]);
    markers = {};
    return
end
bodies = bodySetTree.JMPBody;
scaling = getScalingBodies(bodies);
markers = getMarkers(bodies, model);
end

function inputs = getScalingBodies(bodyTree)
% Lists the names of all bodies whose <scale_body> text is "true".
%
%   bodyTree - a single JMPBody struct, or a cell array of them
%
%   inputs   - string array of body names, in document order
inputs = string([]);
bodyCount = length(bodyTree);
% a lone body is used directly instead of being indexed as a cell
isSingleBody = bodyCount == 1;
for idx = 1:bodyCount
    if isSingleBody
        current = bodyTree;
    else
        current = bodyTree{idx};
    end
    name = current.Attributes.name;
    scaleField = getFieldByNameOrError(current, "scale_body");
    if strcmp(scaleField.Text, "true")
        inputs(end + 1) = name;
    end
end
end

function output = getMarkers(bodyTree, model)
% Lists every (marker, axis) pair that is allowed to move.
%
%   bodyTree - a single JMPBody struct, or a cell array of them
%   model    - model object passed to getMarkersFromBody
%
%   output   - cell array of [markerName, axisLabel] pairs, axisLabel
%              being "x", "y" or "z"; one entry per marker of a body and
%              per axis flagged "true" in that body's move_markers list
axisLabels = ["x", "y", "z"];
output = {};
numBodies = length(bodyTree);
for b = 1:numBodies
    % a lone body is used directly instead of being indexed as a cell
    if numBodies == 1
        body = bodyTree;
    else
        body = bodyTree{b};
    end
    bodyName = body.Attributes.name;
    flags = parseSpaceSeparatedList(body, "move_markers");
    moveAxis = zeros(1, 3);
    for a = 1:3
        moveAxis(a) = flags(a) == "true";
    end
    if ~any(moveAxis)
        continue
    end
    bodyMarkers = getMarkersFromBody(model, bodyName);
    for m = 1:length(bodyMarkers)
        % visit only the enabled axes, in x, y, z order
        for a = find(moveAxis)
            output{end + 1} = [bodyMarkers(m), axisLabels(a)];
        end
    end
end
end

function output = getInitialValues(model, parameters, scaling, markers)
% Collects the current model value of every optimisation variable.
%
%   model      - model object queried by the get*Value helpers
%   parameters - cell array of 4-element joint parameter cells
%   scaling    - array of body names, one scaling value per entry
%   markers    - cell array of [markerName, axis] pairs
%
%   output - row vector ordered as: joint parameters, then scaling
%            values, then marker coordinates
output = [];
for p = 1:length(parameters)
    spec = parameters{p};
    output(p) = getFrameParameterValue(model, spec{1}, ...
        spec{2}, spec{3}, spec{4});
end
for s = 1:length(scaling)
    output(end + 1) = getScalingParameterValue(model, scaling(s));
end
for m = 1:length(markers)
    entry = markers{m};
    [xPosition, yPosition, zPosition] = getMarkerParameterValues( ...
        model, entry(1));
    direction = entry(2);
    % keep only the coordinate named by this entry; any other label
    % contributes nothing
    if strcmp(direction, "x")
        output(end + 1) = xPosition;
    elseif strcmp(direction, "y")
        output(end + 1) = yPosition;
    elseif strcmp(direction, "z")
        output(end + 1) = zPosition;
    end
end
end

function [lowerBounds, upperBounds] = getBounds(parameters, ...
    initialValues, translationBounds, orientationBounds, scaling, markers)
% Builds lower/upper bound vectors matching the layout of initialValues.
%
%   parameters        - cell array of joint parameter cells; element 3 of
%                       each is true for a translation, false otherwise
%   initialValues     - initial value of every variable
%   translationBounds - half-width applied to translation parameters
%   orientationBounds - half-width applied to all other joint parameters
%   scaling, markers  - only their lengths are used; each entry adds one
%                       unbounded (-Inf, Inf) variable
lowerBounds = [];
upperBounds = [];
for k = 1:length(parameters)
    if parameters{k}{3}
        halfWidth = translationBounds;
    else
        halfWidth = orientationBounds;
    end
    lowerBounds(k) = initialValues(k) - halfWidth;
    upperBounds(k) = initialValues(k) + halfWidth;
end
% scaling and marker variables are left unconstrained
numUnbounded = length(scaling) + length(markers);
if numUnbounded > 0
    lowerBounds = [lowerBounds, -Inf(1, numUnbounded)];
    upperBounds = [upperBounds, Inf(1, numUnbounded)];
end
end

function output = getParams(tree)
% Reads the optional numeric optimizer settings from the settings tree.
%
%   tree   - settings tree
%
%   output - struct holding one field per setting found in the tree;
%            settings that are absent are simply left out
%
% Column 1 is the element name looked up in the tree, column 2 the field
% name used in the returned struct.
nameMap = [ ...
    "accuracy",                 "accuracy"; ...
    "diff_min_change",          "diffMinChange"; ...
    "optimality_tolerance",     "optimalityTolerance"; ...
    "function_tolerance",       "functionTolerance"; ...
    "step_tolerance",           "stepTolerance"; ...
    "max_function_evaluations", "maxFunctionEvaluations"];
output = struct();
for row = 1:size(nameMap, 1)
    node = getFieldByName(tree, nameMap(row, 1));
    if ~isstruct(node)
        continue
    end
    output.(nameMap(row, 2)) = str2double(node.Text);
end
end