/*
 * Decompiled with CFR 0.152.
 */
package org.tsers.junitquest;

import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.tsers.junitquest.CallParam;
import org.tsers.junitquest.Jutil;
import org.tsers.junitquest.expr.AndNode;
import org.tsers.junitquest.expr.EqualNode;
import org.tsers.junitquest.expr.ExprNode;
import org.tsers.junitquest.expr.GetFieldNode;
import org.tsers.junitquest.expr.IntNode;
import org.tsers.junitquest.expr.InvokeMethodNode;
import org.tsers.junitquest.expr.LocalNode;
import org.tsers.junitquest.solver.ArithmeticSolver;
import org.tsers.junitquest.solver.InstanceOfSolver;
import org.tsers.junitquest.solver.LogicalSolver;
import org.tsers.junitquest.solver.NullEquationSolver;

/**
 * Turns a constraint expression into candidate call parameters.
 *
 * <p>The expression is flattened into a list of equations, each equation is
 * handed to the arithmetic, instance-of or null solver that recognises it, and
 * the resulting {@link CallParam}s are grouped by position and expanded into
 * every combination (one candidate per position).
 */
public class Solver {
    /**
     * Replaces an {@link InvokeMethodNode} whose return value is present and is
     * an {@link Integer} with an {@link IntNode} holding that value. Every other
     * node is returned unchanged.
     */
    public static Function<ExprNode, ExprNode> transformMethodNode = node -> {
        if (node instanceof InvokeMethodNode) {
            InvokeMethodNode invoke = (InvokeMethodNode)node;
            if (invoke.getReturnValue().isPresent()) {
                Object value = invoke.getReturnValue().get();
                if (value instanceof Integer) {
                    return new IntNode((Integer)value);
                }
            }
        }
        return node;
    };

    /**
     * Replaces a {@link GetFieldNode} whose return value is present and is an
     * {@link Integer} with an {@link IntNode} holding that value. Every other
     * node is returned unchanged.
     */
    public static Function<ExprNode, ExprNode> transformFieldNode = node -> {
        if (node instanceof GetFieldNode) {
            GetFieldNode field = (GetFieldNode)node;
            if (field.getReturnValue().isPresent()) {
                Object value = field.getReturnValue().get();
                if (value instanceof Integer) {
                    return new IntNode((Integer)value);
                }
            }
        }
        return node;
    };

    /**
     * Solves {@code node} twice and combines both results with
     * {@code Jutil.combineLists}: once as given, and once after
     * {@link #transformMethodNode} and then {@link #transformFieldNode} have been
     * applied recursively to it.
     *
     * @param node the expression to solve
     * @return the combined candidate parameter lists of both passes
     */
    public static List<List<CallParam>> solveAll(ExprNode node) {
        List<List<CallParam>> fromOriginal = Solver.solve(node);
        Function<ExprNode, ExprNode> inlineKnownInts = Jutil.applyRecursively(transformMethodNode)
                .andThen(Jutil.applyRecursively(transformFieldNode));
        List<List<CallParam>> fromInlined = Solver.solve(inlineKnownInts.apply(node));
        return Jutil.combineLists(fromOriginal, fromInlined);
    }

    /**
     * Solves the equations contained in {@code node}.
     *
     * <p>Arithmetic equations are first passed through
     * {@link #replaceExtraVariablesWithConstants(ExprNode)}. All solved
     * parameters are grouped by {@code CallParam.getPosition()}, de-duplicated
     * per position, and expanded with {@code Sets.cartesianProduct}.
     *
     * @param node the expression to solve
     * @return every combination of solved parameters, one per position; an
     *         empty list if any exception is thrown while solving
     */
    public static List<List<CallParam>> solve(ExprNode node) {
        try {
            List<ExprNode> initialEquations = Solver.getEquations(node);

            List<CallParam> solvedArithmeticEquations = initialEquations.stream()
                    .filter(n -> ArithmeticSolver.isArithmeticEquation(n))
                    .map(e -> Solver.replaceExtraVariablesWithConstants(e))
                    .flatMap(e -> e.stream())
                    .map(ArithmeticSolver.solveEquation())
                    .map(ArithmeticSolver.solvedEquationToCallParam)
                    .collect(Collectors.toList());

            List<CallParam> solvedInstanceOfEquations = initialEquations.stream()
                    .filter(n -> InstanceOfSolver.isInstanceOfEquation(n))
                    .map(InstanceOfSolver.solveInstanceOfEquation())
                    .flatMap(s -> s.stream())
                    .collect(Collectors.toList());

            List<CallParam> solvedNullEquations = initialEquations.stream()
                    .filter(e -> NullEquationSolver.isNullEquation(e))
                    .map(NullEquationSolver.nullEquationToCallParam)
                    .collect(Collectors.toList());

            List<CallParam> allCallParams =
                    Jutil.combineLists(solvedArithmeticEquations, solvedInstanceOfEquations, solvedNullEquations);

            // One set of distinct alternatives per parameter position.
            List<HashSet<CallParam>> grouped = allCallParams.stream()
                    .collect(Collectors.groupingBy(c -> c.getPosition()))
                    .values().stream()
                    .map(c -> new HashSet<CallParam>(c))
                    .collect(Collectors.toList());

            return new ArrayList<>(Sets.cartesianProduct(grouped));
        }
        catch (Exception ignored) {
            // Best-effort: any failure while solving is reported as "no solutions".
            // Return the same mutable list type as the success path.
            return new ArrayList<>();
        }
    }

    /**
     * Wraps {@code node} in an {@link AndNode}, applies
     * {@code LogicalSolver.combineANDs}, and returns the children of the result.
     */
    private static List<ExprNode> getEquations(ExprNode node) {
        return LogicalSolver.combineANDs.apply(new AndNode(Arrays.asList(node))).getChildren();
    }

    /**
     * Reduces an equation over several locals to equations over one local.
     *
     * <p>If {@code node} references at most one distinct local (by
     * {@code LocalNode.getValue()}), it is returned as the only element.
     * Otherwise every distinct local except the first one (in the grouping
     * map's key iteration order) is substituted, one at a time, by {@code 0},
     * {@code 1} and a random int. Variants that still reference more than one
     * local then have all substituted locals replaced by the random value, for
     * at most 10 passes. The result also contains, for each substituted local,
     * the equations {@code local == random}, {@code local == 0} and
     * {@code local == 1}.
     *
     * @param node the equation to reduce
     * @return the substituted variants followed by the pinning equations, or a
     *         single-element list containing {@code node}
     */
    public static List<ExprNode> replaceExtraVariablesWithConstants(ExprNode node) {
        Map<Integer, List<ExprNode>> groupedByLocals = Solver.groupByLocal(node);
        if (groupedByLocals.keySet().size() <= 1) {
            return Arrays.asList(node);
        }

        int random = new Random().nextInt();
        List<Integer> replacedInts = groupedByLocals.keySet().stream().skip(1L).collect(Collectors.toList());
        List<ExprNode> eqs = Solver.createOriginalReplaceableEquations(random, replacedInts);
        List<ExprNode> replaced = Solver.createReplaceEquations(node, random, replacedInts);

        for (int limit = 0; limit < 10 && Solver.hasMultipleVariables(replaced); ++limit) {
            replaced = replaced.stream()
                    .map(e -> Solver.substituteAll(e, replacedInts, random))
                    .collect(Collectors.toList());
        }
        return Jutil.combineLists(replaced, eqs);
    }

    /** Replaces every local listed in {@code locals} within {@code node} by {@code constant}. */
    private static ExprNode substituteAll(ExprNode node, List<Integer> locals, int constant) {
        ExprNode result = node;
        for (Integer local : locals) {
            result = Solver.substitute(result, local, constant);
        }
        return result;
    }

    /** Replaces the single local {@code local} within {@code node} by {@code constant}. */
    private static ExprNode substitute(ExprNode node, Integer local, int constant) {
        return Jutil.applyRecursively(Solver.localNodeToIntNode(local, constant)).apply(node);
    }

    /**
     * For each local in {@code replacedInts}, yields three copies of
     * {@code node} with that local replaced by {@code 0}, {@code 1} and
     * {@code random}, in that order.
     */
    private static List<ExprNode> createReplaceEquations(ExprNode node, int random, List<Integer> replacedInts) {
        return replacedInts.stream()
                .map(i -> Arrays.asList(
                        Solver.substitute(node, i, 0),
                        Solver.substitute(node, i, 1),
                        Solver.substitute(node, i, random)))
                .flatMap(e -> e.stream())
                .collect(Collectors.toList());
    }

    /**
     * For each local in {@code replacedInts}, yields the equations
     * {@code local == random}, {@code local == 0} and {@code local == 1}, in
     * that order.
     */
    private static List<ExprNode> createOriginalReplaceableEquations(int random, List<Integer> replacedInts) {
        return replacedInts.stream()
                .map(i -> Arrays.asList(
                        Solver.localEqualsConstant(i, random),
                        Solver.localEqualsConstant(i, 0),
                        Solver.localEqualsConstant(i, 1)))
                .flatMap(e -> e.stream())
                .collect(Collectors.toList());
    }

    /** Builds the equation {@code LocalNode(local) == IntNode(constant)}. */
    private static ExprNode localEqualsConstant(int local, int constant) {
        return new EqualNode(Arrays.asList(new LocalNode(local), new IntNode(constant)));
    }

    /** Returns {@code true} if any expression in {@code node} references more than one distinct local. */
    private static boolean hasMultipleVariables(List<ExprNode> node) {
        return node.stream().anyMatch(e -> Solver.hasMultipleVariables(e));
    }

    /** Returns {@code true} if {@code node} references more than one distinct local. */
    private static boolean hasMultipleVariables(ExprNode node) {
        return Solver.groupByLocal(node).keySet().size() > 1;
    }

    /** Groups every {@link LocalNode} found in {@code node} by its {@code getValue()}. */
    private static Map<Integer, List<ExprNode>> groupByLocal(ExprNode node) {
        return Jutil.findAllNodeTypes(LocalNode.class).apply(node).stream()
                .collect(Collectors.groupingBy(n -> ((LocalNode)n).getValue()));
    }

    /**
     * Returns a node mapper that turns a {@link LocalNode} whose value equals
     * {@code localValue} into an {@link IntNode} of {@code intValue} and leaves
     * every other node unchanged.
     */
    private static Function<ExprNode, ExprNode> localNodeToIntNode(Integer localValue, Integer intValue) {
        return node -> {
            boolean isTargetLocal = node instanceof LocalNode
                    && ((LocalNode)node).getValue() == localValue.intValue();
            return isTargetLocal ? new IntNode(intValue) : node;
        };
    }
}

