diff --git a/src/main/java/ca/mcscert/jtet/expression/VariableCollection.java b/src/main/java/ca/mcscert/jtet/expression/VariableCollection.java index 35b2e82749df91635ac50898ce071198115b9209..76cae2df93cf98e11ba29e28428e1c03d25835e9 100644 --- a/src/main/java/ca/mcscert/jtet/expression/VariableCollection.java +++ b/src/main/java/ca/mcscert/jtet/expression/VariableCollection.java @@ -163,6 +163,47 @@ final public class VariableCollection { return getAllVariables().values(); } + /** + * Helper function to make a set of all variable types from a given collection. + * @param variables Variables to extract types from. + * @return The collected type set. + */ + private Set getTypesFromVariableCollection(Collection variables) { + final Set ret = new LinkedHashSet(); + + for (Variable var : variables) { + ret.add(var.type()); + } + + return ret; + } + + /** + * Returns a set containing all the variable types used in the input variables. + * @return A set of used variable types. + */ + public Set getInputVariableTypes() { + return getTypesFromVariableCollection(getInputVariablesList()); + } + + /** + * Returns a set containing all the variable types used in the output variables. + * @return A set of used variable types. + */ + public Set getOutputVariableTypes() { + return getTypesFromVariableCollection(getOutputVariablesList()); + } + + /** + * Returns a set containing all the variable types used. + *

+ * This is equivalant to the union of {@link #getInputVariableTypes()} and {@link #getOutputVariableTypes()}. + * @return A set of used variable types. + */ + public Set getAllVariableTypes() { + return getTypesFromVariableCollection(getAllVariablesList()); + } + private final Map m_inputVariables = new LinkedHashMap(); private final Map m_outputVariables = new LinkedHashMap(); private final Map m_variablesAndEnumeratedValues = new HashMap(); diff --git a/src/test/java/ca/mcscert/jtet/expression/test/VariableCollectionTest.java b/src/test/java/ca/mcscert/jtet/expression/test/VariableCollectionTest.java index 66ba0a1ca16e4cb6bb868cca94eeedc299b1123b..b3f1a0f4c417dc283dd6850dedf6ff18cf899145 100644 --- a/src/test/java/ca/mcscert/jtet/expression/test/VariableCollectionTest.java +++ b/src/test/java/ca/mcscert/jtet/expression/test/VariableCollectionTest.java @@ -137,4 +137,41 @@ public class VariableCollectionTest { new VariableCollection(new PartialVariableCollection(var, enums)); // Should throw as e1 the variable conflicts with the enumerated name e1. } + + @Test + public void VariableTypeFetchingWorks() { + final EnumerationVariableType type1 = new EnumerationVariableType("myenum"); + type1.enumerationValues().add("e1_1"); + + final EnumerationVariableType type2 = new EnumerationVariableType("myenum"); + type2.enumerationValues().add("e2_2"); + + final Variable var1e = new Variable("v1e", type1); + final Variable var1r = new Variable("v1r", new RealVariableType()); + final Variable var2e = new Variable("v2e", type2); + + final Set inputTypeSet = new HashSet(2); + inputTypeSet.add(type1); + inputTypeSet.add(new RealVariableType()); + + final Set outputTypeSet = new HashSet(1); + outputTypeSet.add(type2); + + final Set allTypeSet = new HashSet(2); + allTypeSet.addAll(inputTypeSet); + allTypeSet.addAll(outputTypeSet); + + final Map vars1 = new HashMap(); + vars1.put(var1e.name(), var1e); + vars1.put(var1r.name(), var1r); + + final Map vars2 = new HashMap(); + vars2.put(var2e.name(), var2e); + + final VariableCollection vc = new VariableCollection(new PartialVariableCollection(vars1), new PartialVariableCollection(vars2)); + + assertThat(vc.getInputVariableTypes(), equalTo(inputTypeSet)); + assertThat(vc.getOutputVariableTypes(), equalTo(outputTypeSet)); + assertThat(vc.getAllVariableTypes(), equalTo(allTypeSet)); + } }