From 5238c323800edeb9ff74a1ba19bd4ec0b65ac13d Mon Sep 17 00:00:00 2001 From: Riccardo Azzolini Date: Mon, 19 Nov 2018 13:05:28 +0100 Subject: [PATCH] Implement Pattern-based Rule --- .../warppi/math/rules/dsl/PatternRule.java | 70 ++++++++++ .../math/rules/dsl/PatternRuleTest.java | 127 ++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 core/src/main/java/it/cavallium/warppi/math/rules/dsl/PatternRule.java create mode 100644 core/src/test/java/it/cavallium/warppi/math/rules/dsl/PatternRuleTest.java diff --git a/core/src/main/java/it/cavallium/warppi/math/rules/dsl/PatternRule.java b/core/src/main/java/it/cavallium/warppi/math/rules/dsl/PatternRule.java new file mode 100644 index 00000000..8a503421 --- /dev/null +++ b/core/src/main/java/it/cavallium/warppi/math/rules/dsl/PatternRule.java @@ -0,0 +1,70 @@ +package it.cavallium.warppi.math.rules.dsl; + +import it.cavallium.warppi.math.Function; +import it.cavallium.warppi.math.MathContext; +import it.cavallium.warppi.math.rules.Rule; +import it.cavallium.warppi.math.rules.RuleType; +import it.cavallium.warppi.util.Error; +import it.unimi.dsi.fastutil.objects.ObjectArrayList; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * A Rule which uses Patterns to match and replace functions. + */ +public class PatternRule implements Rule { + private final String ruleName; + private final RuleType ruleType; + private final Pattern target; + private final List replacements; + + public PatternRule( + final String ruleName, + final RuleType ruleType, + final Pattern target, + final List replacements + ) { + this.ruleName = ruleName; + this.ruleType = ruleType; + this.target = target; + this.replacements = replacements; + } + + public PatternRule( + final String ruleName, + final RuleType ruleType, + final Pattern target, + final Pattern... replacements + ) { + this(ruleName, ruleType, target, Arrays.asList(replacements)); + } + + @Override + public String getRuleName() { + return ruleName; + } + + @Override + public RuleType getRuleType() { + return ruleType; + } + + @Override + public ObjectArrayList execute(final Function func) throws Error, InterruptedException { + return target.match(func) + .map(subFunctions -> applyReplacements(func.getMathContext(), subFunctions)) + .orElse(null); + } + + private ObjectArrayList applyReplacements( + final MathContext mathContext, + final Map subFunctions + ) { + return replacements.stream() + .map(replacement -> replacement.replace(mathContext, subFunctions)) + .collect(Collectors.toCollection(ObjectArrayList::new)); + } +} diff --git a/core/src/test/java/it/cavallium/warppi/math/rules/dsl/PatternRuleTest.java b/core/src/test/java/it/cavallium/warppi/math/rules/dsl/PatternRuleTest.java new file mode 100644 index 00000000..44832254 --- /dev/null +++ b/core/src/test/java/it/cavallium/warppi/math/rules/dsl/PatternRuleTest.java @@ -0,0 +1,127 @@ +package it.cavallium.warppi.math.rules.dsl; + +import it.cavallium.warppi.math.Function; +import it.cavallium.warppi.math.MathContext; +import it.cavallium.warppi.math.functions.Number; +import it.cavallium.warppi.math.functions.Subtraction; +import it.cavallium.warppi.math.functions.Sum; +import it.cavallium.warppi.math.functions.SumSubtraction; +import it.cavallium.warppi.math.rules.RuleType; +import it.cavallium.warppi.math.rules.dsl.patterns.*; +import it.cavallium.warppi.util.Error; +import it.cavallium.warppi.util.Errors; +import it.unimi.dsi.fastutil.objects.ObjectArrayList; +import org.junit.Test; + +import java.math.BigDecimal; + +import static org.junit.Assert.*; + +public class PatternRuleTest { + private final MathContext mathContext = new MathContext(); + + private final Pattern x = new SubFunctionPattern("x"); + private final Pattern xPlus0 = new SumPattern( + x, + new NumberPattern(new BigDecimal(0)) + ); + + @Test + public void testNonMatching() throws InterruptedException, Error { + final Function func = new Sum( + mathContext, + new Number(mathContext, 1), + new Number(mathContext, 2) + ); + + final PatternRule rule = new PatternRule("TestRule", RuleType.REDUCTION, xPlus0, xPlus0); + assertNull(func.simplify(rule)); + } + + @Test + public void testMatching() throws InterruptedException, Error { + final Function func = new Sum( + mathContext, + new Number(mathContext, 1), + new Number(mathContext, 0) + ); + + final PatternRule identityRule = new PatternRule("Identity", RuleType.REDUCTION, xPlus0, xPlus0); + final ObjectArrayList identityResult = func.simplify(identityRule); + assertEquals(1, identityResult.size()); + assertEquals(func, identityResult.get(0)); + + final PatternRule simplifyRule = new PatternRule("Simplify", RuleType.REDUCTION, xPlus0, x); + final ObjectArrayList simplifyResult = func.simplify(simplifyRule); + assertEquals(1, identityResult.size()); + assertEquals(new Number(mathContext, 1), simplifyResult.get(0)); + } + + @Test + public void testMatchingRecursive() throws InterruptedException, Error { + final Function func = new Sum( + mathContext, + new Number(mathContext, 3), + new Sum( + mathContext, + new Number(mathContext, 5), + new Number(mathContext, 0) + ) + ); + + final PatternRule identityRule = new PatternRule("Identity", RuleType.REDUCTION, xPlus0, xPlus0); + final ObjectArrayList identityResult = func.simplify(identityRule); + assertEquals(1, identityResult.size()); + assertEquals(func, identityResult.get(0)); + + final PatternRule simplifyRule = new PatternRule("Simplify", RuleType.REDUCTION, xPlus0, x); + final ObjectArrayList simplifyResult = func.simplify(simplifyRule); + assertEquals(1, identityResult.size()); + final Function expected = new Sum( + mathContext, + new Number(mathContext, 3), + new Number(mathContext, 5) + ); + assertEquals(expected, simplifyResult.get(0)); + } + + @Test + public void testMultipleReplacements() throws InterruptedException, Error { + final Number one = new Number(mathContext, 1); + final Number two = new Number(mathContext, 2); + final Function func = new SumSubtraction(mathContext, one, two); + + final Pattern x = new SubFunctionPattern("x"); + final Pattern y = new SubFunctionPattern("y"); + final PatternRule rule = new PatternRule( + "TestRule", + RuleType.EXPANSION, + new SumSubtractionPattern(x, y), + new SumPattern(x, y), new SubtractionPattern(x, y) + ); + + final ObjectArrayList result = func.simplify(rule); + final ObjectArrayList expected = ObjectArrayList.wrap(new Function[]{ + new Sum(mathContext, one, two), + new Subtraction(mathContext, one, two) + }); + assertEquals(expected, result); + } + + @Test + public void testNoReplacements() throws InterruptedException, Error { + final Function func = new Sum( + mathContext, + new Number(mathContext, 1), + new Number(mathContext, 2) + ); + + final PatternRule rule = new PatternRule( + "TestRule", + RuleType.REDUCTION, + new SubFunctionPattern("x") + ); + + assertTrue(func.simplify(rule).isEmpty()); + } +}