/*
 * Decompiled with CFR 0.152.
 */
package org.tsers.junitquest.solver;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.tsers.junitquest.CallParam;
import org.tsers.junitquest.Jutil;
import org.tsers.junitquest.expr.AddNode;
import org.tsers.junitquest.expr.DivisionNode;
import org.tsers.junitquest.expr.EqualNode;
import org.tsers.junitquest.expr.ExprNode;
import org.tsers.junitquest.expr.GetFieldNode;
import org.tsers.junitquest.expr.GreaterThanEqNode;
import org.tsers.junitquest.expr.GreaterThanNode;
import org.tsers.junitquest.expr.IntNode;
import org.tsers.junitquest.expr.InvokeVirtualNode;
import org.tsers.junitquest.expr.LessThanEqNode;
import org.tsers.junitquest.expr.LessThanNode;
import org.tsers.junitquest.expr.LocalNode;
import org.tsers.junitquest.expr.MinusNode;
import org.tsers.junitquest.expr.MultiplicationNode;
import org.tsers.junitquest.expr.NotEqualNode;
import org.tsers.junitquest.instance.PrimitiveInstance;
import org.tsers.junitquest.solver.InstanceOfSolver;
import org.tsers.junitquest.solver.LogicalSolver;

/**
 * Symbolic solver for integer arithmetic relations expressed as {@link ExprNode} trees.
 *
 * <p>The public function fields are rewrite passes (negation, constant folding, distribution of
 * products over sums). {@link #solveEquation()} uses them to isolate the {@link LocalNode} with
 * the smallest value on the left side of a relation, and {@link #solvedEquationToCallParam}
 * turns that isolated form into a concrete {@link CallParam}.
 *
 * <p>Function fields that call themselves are referenced through the class name: Java rejects a
 * simple-name self-reference inside a field's own initializer, even within a lambda body.
 */
public class ArithmeticSolver {
    /** Maximum number of isolation steps attempted by {@code reduceValuesAndConstants}. */
    private static final int MAX_REDUCTION_TRYS = 5;

    /**
     * Negates an expression. An {@link IntNode} is negated numerically, a {@link MinusNode} is
     * unwrapped, {@code complement} is mapped over the children of an {@link AddNode}, and any
     * other node is wrapped in a new {@link MinusNode}.
     */
    public static Function<ExprNode, ExprNode> complement = node -> {
        if (node instanceof IntNode) {
            return new IntNode(-((IntNode) node).getValue());
        }
        if (node instanceof MinusNode) {
            return node.getChildren().get(0);
        }
        if (node instanceof AddNode) {
            return Jutil.mapChildren.apply(node, ArithmeticSolver.complement);
        }
        return new MinusNode(node);
    };

    /**
     * Pushes negations inward. For a non-{@link MinusNode}, this function is mapped over its
     * children. For a MinusNode, a double negation is removed, an integer is negated, a sum has
     * {@code complement} mapped over its terms, and a product has its first factor complemented.
     * Any other MinusNode is returned unchanged.
     */
    public static Function<ExprNode, ExprNode> reduceComplement = node -> {
        if (!(node instanceof MinusNode)) {
            return Jutil.mapChildren.apply(node, ArithmeticSolver.reduceComplement);
        }
        ExprNode negated = node.getChildren().get(0);
        if (negated instanceof MinusNode) {
            return ArithmeticSolver.reduceComplement.apply(negated.getChildren().get(0));
        }
        if (negated instanceof IntNode) {
            return new IntNode(-((IntNode) negated).getValue());
        }
        if (negated instanceof AddNode) {
            return Jutil.mapChildren.apply(negated, complement);
        }
        if (negated instanceof MultiplicationNode) {
            return Jutil.mapFirstChild.apply(negated, complement);
        }
        return node;
    };

    /**
     * Multiplies {@code node} by an integer. {@code (a / k) * k} cancels to {@code a}, integers
     * are folded, and everything else is wrapped in a {@link MultiplicationNode}.
     */
    public static BiFunction<ExprNode, Integer, ExprNode> multiply = (node, multiplier) -> {
        if (node instanceof DivisionNode) {
            // Only the divisor (second operand, as built by `divide`) may cancel: (k / a) * k != a.
            List<ExprNode> operands = node.getChildren();
            if (operands.size() == 2 && isIntWithValue(operands.get(1), multiplier)) {
                return operands.get(0);
            }
        } else if (node instanceof IntNode) {
            return new IntNode(((IntNode) node).getValue() * multiplier);
        }
        return new MultiplicationNode(Arrays.asList(node, new IntNode((int) multiplier)));
    };

    /**
     * Divides {@code node} by an integer. One matching integer factor of a product is cancelled,
     * integers are folded with Java integer division, and everything else is wrapped in a
     * {@link DivisionNode}. Dividing an integer by zero yields a random integer.
     */
    public static BiFunction<ExprNode, Integer, ExprNode> divide = (node, denominator) -> {
        if (node instanceof MultiplicationNode) {
            List<ExprNode> factors = new ArrayList<>(node.getChildren());
            int match = indexOfIntWithValue(factors, denominator);
            if (match >= 0) {
                // Cancel exactly one factor and keep every other one (x * y * k / k == x * y).
                factors.remove(match);
                if (factors.isEmpty()) {
                    return new IntNode(1);
                }
                if (factors.size() == 1) {
                    return factors.get(0);
                }
                return new MultiplicationNode(factors);
            }
        } else if (node instanceof IntNode) {
            if (denominator == 0) {
                return new IntNode(new Random().nextInt());
            }
            return new IntNode(((IntNode) node).getValue() / denominator);
        }
        return new DivisionNode(Arrays.asList(node, new IntNode((int) denominator)));
    };

    /** Wraps {@code node} in a single-child {@link AddNode} unless it already is one. */
    private static final Function<ExprNode, ExprNode> setAddNodeAsParent = node ->
            node instanceof AddNode ? node : new AddNode(Arrays.asList(node));

    /**
     * Folds the integer terms of an {@link AddNode} into a single leading {@link IntNode}
     * (omitted when there are none). Non-sum nodes are returned unchanged.
     */
    private static final Function<ExprNode, ExprNode> addConstantsTogether = node -> {
        if (!(node instanceof AddNode)) {
            return node;
        }
        List<ExprNode> constants = node.getChildren().stream()
                .filter(n -> n instanceof IntNode)
                .collect(Collectors.toList());
        List<ExprNode> sum = new ArrayList<>();
        if (!constants.isEmpty()) {
            sum.add(new IntNode(constants.stream().mapToInt(n -> ((IntNode) n).getValue()).sum()));
        }
        List<ExprNode> variables = node.getChildren().stream()
                .filter(n -> !(n instanceof IntNode))
                .collect(Collectors.toList());
        return new AddNode(Jutil.combineLists(sum, variables));
    };

    /**
     * Normalises signs inside a product. Pairs of {@link MinusNode} factors cancel (an odd count
     * keeps the first one). A single MinusNode next to exactly one integer factor is moved onto
     * that integer.
     */
    private static final Function<ExprNode, ExprNode> reduceMinusInMultiplication = node -> {
        List<ExprNode> minusNodes = node.getChildren().stream()
                .filter(n -> n instanceof MinusNode)
                .collect(Collectors.toList());
        List<ExprNode> nonMinusNodes = node.getChildren().stream()
                .filter(n -> !(n instanceof MinusNode))
                .collect(Collectors.toList());
        if (minusNodes.size() > 1) {
            if (minusNodes.size() % 2 == 0) {
                List<ExprNode> complemented = minusNodes.stream()
                        .map(complement)
                        .collect(Collectors.toList());
                return new MultiplicationNode(Jutil.combineLists(nonMinusNodes, complemented));
            }
            ExprNode firstMinus = minusNodes.get(0);
            List<ExprNode> complementedButFirst = minusNodes.stream()
                    .skip(1L)
                    .map(complement)
                    .collect(Collectors.toList());
            return new MultiplicationNode(
                    Jutil.combineLists(Arrays.asList(firstMinus), nonMinusNodes, complementedButFirst));
        }
        if (minusNodes.size() == 1) {
            List<ExprNode> intNodes = nonMinusNodes.stream()
                    .filter(n -> n instanceof IntNode)
                    .collect(Collectors.toList());
            if (intNodes.size() == 1) {
                ExprNode negatedInt = complement.apply(intNodes.get(0));
                ExprNode unwrappedMinus = complement.apply(minusNodes.get(0));
                List<ExprNode> rest = nonMinusNodes.stream()
                        .filter(n -> !(n instanceof IntNode))
                        .collect(Collectors.toList());
                return new MultiplicationNode(Jutil.combineLists(
                        Arrays.asList(negatedInt), rest, Arrays.asList(unwrappedMinus)));
            }
        }
        return node;
    };

    /**
     * Folds two or more integer factors of a node into one {@link IntNode}, placed after the
     * remaining factors. Returns the bare integer when nothing else remains.
     */
    private static final Function<ExprNode, ExprNode> multiplyIntegersTogether = node -> {
        List<ExprNode> intNodes = node.getChildren().stream()
                .filter(n -> n instanceof IntNode)
                .collect(Collectors.toList());
        if (intNodes.size() <= 1) {
            return node;
        }
        int product = intNodes.stream().mapToInt(n -> ((IntNode) n).getValue()).reduce(1, (a, b) -> a * b);
        ExprNode productNode = new IntNode(product);
        List<ExprNode> rest = node.getChildren().stream()
                .filter(n -> !(n instanceof IntNode))
                .collect(Collectors.toList());
        if (rest.isEmpty()) {
            return productNode;
        }
        return new MultiplicationNode(Jutil.combineLists(rest, Arrays.asList(productNode)));
    };

    /**
     * Distributes a product over one of its sum factors: {@code (a + b) * r} becomes
     * {@code a * r + b * r}. Returns the product unchanged when {@code sum} is empty.
     */
    private static final BiFunction<ExprNode, Optional<ExprNode>, ExprNode> transformMultiplications = (product, sum) -> {
        if (!sum.isPresent()) {
            return product;
        }
        ExprNode distributed = sum.get();
        // Remove only this occurrence; filtering by equals() would also drop a repeated factor
        // and turn (a + b) * (a + b) into a + b.
        List<ExprNode> rest = new ArrayList<>(product.getChildren());
        rest.remove(distributed);
        List<ExprNode> terms = distributed.getChildren().stream()
                .map(term -> (ExprNode) new MultiplicationNode(Jutil.combineLists(Arrays.asList(term), rest)))
                .collect(Collectors.toList());
        return new AddNode(terms);
    };

    /**
     * Simplifies a {@link MultiplicationNode}: distributes over its first sum factor, folds
     * integer factors, normalises signs, then flattens via {@link LogicalSolver#combineNodes}.
     * For any other node, this function is mapped over its children.
     */
    public static Function<ExprNode, ExprNode> reduceMultiplication = node -> {
        if (!(node instanceof MultiplicationNode)) {
            return Jutil.mapChildren.apply(node, ArithmeticSolver.reduceMultiplication);
        }
        Optional<ExprNode> firstSum = node.getChildren().stream()
                .filter(n -> n instanceof AddNode)
                .findFirst();
        ExprNode transformed = transformMultiplications
                .andThen(multiplyIntegersTogether)
                .andThen(reduceMinusInMultiplication)
                .apply(node, firstSum);
        return LogicalSolver.combineNodes(transformed, MultiplicationNode.class);
    };

    /**
     * Applies all reduction passes repeatedly until the expression no longer changes (by
     * {@code equals}). There is no iteration bound: a pass sequence that never reaches a fixed
     * point recurses until the stack overflows.
     */
    public static Function<ExprNode, ExprNode> reduceExpression = node -> {
        ExprNode reduced = addConstantsTogether
                .andThen(reduceMultiplication)
                .andThen(LogicalSolver.reduceAdds)
                .andThen(reduceComplement)
                .apply(node);
        if (!reduced.equals(node)) {
            return ArithmeticSolver.reduceExpression.apply(reduced);
        }
        return node;
    };

    /**
     * Solves a two-sided relation for {@code inRespective}. Terms that contain it (or its
     * negation) are gathered on the left and all other terms on the right, then the left side is
     * reduced until it is a bare variable.
     *
     * @return a node of the relation's class, mirrored when a step multiplied both sides by a
     *     negative number, with children {@code [isolated term, constant side]}
     * @throws RuntimeException if the term cannot be isolated within the step limit
     */
    public static BiFunction<ExprNode, ExprNode, ExprNode> solveEquation = (equation, inRespective) -> {
        Function<ExprNode, ExprNode> normalizeSide = reduceExpression.andThen(setAddNodeAsParent);
        ExprNode leftSide = normalizeSide.apply(equation.getChildren().get(0));
        ExprNode rightSide = normalizeSide.apply(equation.getChildren().get(1));
        // lhsValues + lhsConstants = rhsValues + rhsConstants
        //   => lhsValues - rhsValues = rhsConstants - lhsConstants
        List<ExprNode> leftSideValues = leftSide.getChildren().stream()
                .filter(n -> equalsOrComplementEquals(n, inRespective))
                .collect(Collectors.toList());
        List<ExprNode> leftSideConstants = leftSide.getChildren().stream()
                .filter(n -> !equalsOrComplementEquals(n, inRespective))
                .map(complement)
                .collect(Collectors.toList());
        List<ExprNode> rightSideValues = rightSide.getChildren().stream()
                .filter(n -> equalsOrComplementEquals(n, inRespective))
                .map(complement)
                .collect(Collectors.toList());
        List<ExprNode> rightSideConstants = rightSide.getChildren().stream()
                .filter(n -> !equalsOrComplementEquals(n, inRespective))
                .collect(Collectors.toList());
        ExprNode constants = reduceExpression.apply(new AddNode(Jutil.combineLists(leftSideConstants, rightSideConstants)));
        ExprNode values = reduceExpression.apply(new AddNode(Jutil.combineLists(leftSideValues, rightSideValues)));
        return reduceValuesAndConstants(equation, constants, values);
    };

    /**
     * Converts a solved relation {@code relation(LocalNode, IntNode)} into a call parameter.
     * The value is the solved integer, shifted so that the relation holds.
     */
    public static Function<ExprNode, CallParam> solvedEquationToCallParam = equation -> {
        Integer location = ((LocalNode) equation.getChildren().get(0)).getValue();
        Integer solvedValue = ((IntNode) equation.getChildren().get(1)).getValue();
        int callParamValue = solvedValue + getValueToSatisfyEquation(equation);
        return new CallParam(new PrimitiveInstance(callParamValue), location);
    };

    /**
     * Heuristically decides whether {@code exprNode} should be handled by this solver.
     *
     * <p>The {@link InstanceOfSolver} exclusion applies only to the {@link IntNode} check, because
     * {@code &&} binds tighter than {@code ||}. The parentheses make that explicit.
     *
     * <p>NOTE(review): {@link LessThanEqNode} is absent while the other three orderings are
     * present, so e.g. {@code x <= y} with no integer literal is not classified as arithmetic.
     * Confirm whether that is intended.
     */
    public static boolean isArithmeticEquation(ExprNode exprNode) {
        return Jutil.containsClazz(exprNode, AddNode.class)
                || Jutil.containsClazz(exprNode, MinusNode.class)
                || Jutil.containsClazz(exprNode, GreaterThanEqNode.class)
                || Jutil.containsClazz(exprNode, GreaterThanNode.class)
                || Jutil.containsClazz(exprNode, LessThanNode.class)
                || (Jutil.containsClazz(exprNode, IntNode.class) && !InstanceOfSolver.isInstanceOfEquation(exprNode));
    }

    /**
     * Returns a function that finds the {@link LocalNode} with the smallest value in an equation.
     * The returned function throws {@code NoSuchElementException} if the equation has none.
     */
    public static Function<ExprNode, LocalNode> getSmallestLocationNode() {
        return equation -> (LocalNode) Jutil.findAllNodeTypes(LocalNode.class).apply(equation).stream()
                .min((a, b) -> Integer.compare(((LocalNode) a).getValue(), ((LocalNode) b).getValue()))
                .get();
    }

    /** Returns a function that solves an equation for its smallest-valued {@link LocalNode}. */
    public static Function<ExprNode, ExprNode> solveEquation() {
        return equation -> {
            ExprNode solveInRespective = getSmallestLocationNode().apply(equation);
            return solveEquation.apply(equation, solveInRespective);
        };
    }

    /**
     * Returns whether {@code node}, its complement, or any descendant (or a descendant's
     * complement) equals {@code nodeToFind}.
     */
    private static boolean equalsOrComplementEquals(ExprNode node, ExprNode nodeToFind) {
        if (node.equals(nodeToFind) || complement.apply(node).equals(nodeToFind)) {
            return true;
        }
        return node.getChildren().stream().anyMatch(child -> equalsOrComplementEquals(child, nodeToFind));
    }

    /**
     * Returns the step that peels one layer off {@code values}: {@code complement} for a
     * {@link MinusNode}, division by a product's integer factor, multiplication by a quotient's
     * integer divisor, and identity otherwise (including products and quotients with no integer
     * factor to remove).
     */
    public static Function<ExprNode, ExprNode> getFunctionsToReduceValues(ExprNode values) {
        if (values instanceof MinusNode) {
            return complement;
        }
        if (values instanceof LocalNode) {
            return node -> node;
        }
        Optional<Integer> factor = reducingFactor(values);
        if (!factor.isPresent()) {
            return node -> node;
        }
        int value = factor.get();
        if (values instanceof MultiplicationNode) {
            return node -> divide.apply(node, value);
        }
        return node -> multiply.apply(node, value);
    }

    /**
     * Applies the same reduction step to both sides until {@code values} is a bare variable
     * ({@link LocalNode}, {@link InvokeVirtualNode} or {@link GetFieldNode}). Each step that
     * multiplies both sides by a negative number mirrors an ordering relation.
     *
     * @throws RuntimeException if the variable is not isolated within {@link #MAX_REDUCTION_TRYS} steps
     */
    private static ExprNode reduceValuesAndConstants(ExprNode equation, ExprNode constants, ExprNode values) {
        Class<? extends ExprNode> relation = equation.getClass();
        for (int attempt = 0; attempt < MAX_REDUCTION_TRYS; ++attempt) {
            if (reversesOrdering(values)) {
                // e.g. -x < c  <=>  x > -c
                relation = mirrorRelation(relation);
            }
            Function<ExprNode, ExprNode> reduceFunctions = getFunctionsToReduceValues(values);
            values = reduceFunctions.apply(values);
            constants = reduceFunctions.apply(constants);
            if (values instanceof LocalNode || values instanceof InvokeVirtualNode || values instanceof GetFieldNode) {
                return Jutil.createNode.apply(relation, Arrays.asList(values, constants));
            }
        }
        throw new RuntimeException("Cannot reduce expression");
    }

    /**
     * Returns the integer that {@link #getFunctionsToReduceValues} divides or multiplies by: the
     * first integer factor of a product, or the integer divisor of a two-operand quotient.
     */
    private static Optional<Integer> reducingFactor(ExprNode values) {
        if (values instanceof MultiplicationNode) {
            return values.getChildren().stream()
                    .filter(n -> n instanceof IntNode)
                    .map(n -> ((IntNode) n).getValue())
                    .findFirst();
        }
        if (values instanceof DivisionNode) {
            List<ExprNode> operands = values.getChildren();
            if (operands.size() == 2 && operands.get(1) instanceof IntNode) {
                return Optional.of(((IntNode) operands.get(1)).getValue());
            }
        }
        return Optional.empty();
    }

    /** Returns whether the next reduction step multiplies both sides by a negative number. */
    private static boolean reversesOrdering(ExprNode values) {
        if (values instanceof MinusNode) {
            return true;
        }
        return reducingFactor(values).map(factor -> factor < 0).orElse(false);
    }

    /** Swaps {@code <} with {@code >} and {@code <=} with {@code >=}; other relations are unchanged. */
    private static Class<? extends ExprNode> mirrorRelation(Class<? extends ExprNode> relation) {
        if (relation.equals(LessThanNode.class)) {
            return GreaterThanNode.class;
        }
        if (relation.equals(GreaterThanNode.class)) {
            return LessThanNode.class;
        }
        if (relation.equals(LessThanEqNode.class)) {
            return GreaterThanEqNode.class;
        }
        if (relation.equals(GreaterThanEqNode.class)) {
            return LessThanEqNode.class;
        }
        return relation;
    }

    /** Returns whether {@code node} is an {@link IntNode} holding {@code value}. */
    private static boolean isIntWithValue(ExprNode node, int value) {
        return node instanceof IntNode && ((IntNode) node).getValue() == value;
    }

    /** Returns the index of the first {@link IntNode} holding {@code value}, or -1 if none. */
    private static int indexOfIntWithValue(List<ExprNode> nodes, int value) {
        for (int i = 0; i < nodes.size(); ++i) {
            if (isIntWithValue(nodes.get(i), value)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Returns the offset added to a solved value so that the relation holds: 0 for equality, a
     * random non-zero number for inequality, -1 for {@code <}/{@code <=}, and 1 for
     * {@code >}/{@code >=}.
     */
    private static int getValueToSatisfyEquation(ExprNode equationRoot) {
        if (equationRoot instanceof EqualNode) {
            return 0;
        }
        if (equationRoot instanceof NotEqualNode) {
            // Any non-zero offset differs from the solved value, even if the int addition wraps.
            Random random = new Random();
            int offset;
            do {
                offset = random.nextInt();
            } while (offset == 0);
            return offset;
        }
        if (equationRoot instanceof LessThanNode || equationRoot instanceof LessThanEqNode) {
            return -1;
        }
        if (equationRoot instanceof GreaterThanNode || equationRoot instanceof GreaterThanEqNode) {
            return 1;
        }
        return 0;
    }
}

