/*
 * Decompiled with CFR 0.152.
 */
package tlc2.tool;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.LongAdder;
import tlc2.tool.Action;
import tlc2.tool.ITool;
import tlc2.tool.SimulationWorker;
import tlc2.tool.Simulator;
import tlc2.tool.StateVec;
import tlc2.tool.TLCState;
import tlc2.tool.liveness.ILiveCheck;
import tlc2.util.RandomGenerator;

/**
 * A {@link SimulationWorker} that chooses the next action with a softmax policy over a
 * tabular Q-function. After every trace it walks the trace backwards and applies the
 * Q-learning update {@code Q(p,a) := (1-ALPHA)*Q(p,a) + ALPHA*(r + GAMMA*max_b Q(s,b))}.
 *
 * <p>The Q-table maps each {@link Action} returned by {@link ITool#getActions()} to a map
 * from state hash (see {@link #getHash(TLCState)}) to Q value. Tuning is read from system
 * properties prefixed with {@code Simulator.class.getName() + ".rl."}:
 * <ul>
 *   <li>{@code alpha}: learning rate (default 0.3);</li>
 *   <li>{@code gamma}: discount factor (default 0.7);</li>
 *   <li>{@code reward}: value passed to {@link ITool#evalReward} (default -10);</li>
 *   <li>{@code enabledOnly}: if true, only actions with at least one successor state are
 *       offered to the policy.</li>
 * </ul>
 *
 * <p>The table is a plain {@link HashMap} held by each worker instance. It is not
 * synchronized.
 */
public class RLSimulationWorker
extends SimulationWorker {
    protected static final double ALPHA = Double.parseDouble(System.getProperty(Simulator.class.getName() + ".rl.alpha", ".3d"));
    protected static final double GAMMA = Double.parseDouble(System.getProperty(Simulator.class.getName() + ".rl.gamma", ".7d"));
    protected static final double REWARD = Double.parseDouble(System.getProperty(Simulator.class.getName() + ".rl.reward", "-10d"));
    protected static final boolean ENABLED_ONLY = Boolean.getBoolean(Simulator.class.getName() + ".rl.enabledOnly");

    /**
     * Finite stand-in for "minus infinity" in the Q-table. It is deliberately finite rather
     * than {@link Double#NEGATIVE_INFINITY}, so that the max-shift in
     * {@link #getNextActionIndex} ({@code MIN_Q - MIN_Q == 0}) never yields NaN.
     */
    private static final double MIN_Q = -Double.MAX_VALUE;

    /** Q-table: action, then state hash, then Q value. */
    protected final Map<Action, Map<Long, Double>> q = new HashMap<>();

    public RLSimulationWorker(int id, ITool tool, BlockingQueue<SimulationWorker.SimulationWorkerResult> resultQueue, long seed, int maxTraceDepth, long maxTraceNum, boolean checkDeadlock, String traceFile, ILiveCheck liveCheck) {
        this(id, tool, resultQueue, seed, maxTraceDepth, maxTraceNum, null, checkDeadlock, traceFile, liveCheck, new LongAdder(), new AtomicLong(), new AtomicLong());
    }

    public RLSimulationWorker(int id, ITool tool, BlockingQueue<SimulationWorker.SimulationWorkerResult> resultQueue, long seed, int maxTraceDepth, long maxTraceNum, String traceActions, boolean checkDeadlock, String traceFile, ILiveCheck liveCheck, LongAdder numOfGenStates, AtomicLong numOfGenTraces, AtomicLong m2AndMean) {
        super(id, tool, resultQueue, seed, maxTraceDepth, maxTraceNum, traceActions, checkDeadlock, traceFile, liveCheck, numOfGenStates, numOfGenTraces, m2AndMean);
        // One (initially empty) state->Q map per action known to the tool.
        for (Action a : tool.getActions()) {
            this.q.put(a, new HashMap<>());
        }
    }

    /**
     * Returns the reward for the transition {@code s -a-> t}, as computed by
     * {@link ITool#evalReward} with {@link #REWARD} as its third argument.
     * NOTE(review): REWARD presumably acts as the default when the spec defines no reward;
     * confirm against ITool#evalReward.
     */
    protected double getReward(TLCState s, Action a, TLCState t) {
        return this.tool.evalReward(s, t, REWARD);
    }

    /**
     * Returns the largest Q value recorded for state hash {@code fp} across all actions.
     * Returns {@link #MIN_Q} if no action has an entry for {@code fp}.
     * NOTE(review): a trace's final state is typically never expanded, so this makes the
     * last transition of each trace strongly negative. Confirm this is intended.
     */
    private double getMaxQ(long fp) {
        double max = MIN_Q;
        for (Map<Long, Double> perState : this.q.values()) {
            max = Math.max(max, perState.getOrDefault(fp, MIN_Q));
        }
        return max;
    }

    /** Returns the key under which {@code state} is stored in the Q-table. */
    protected long getHash(TLCState state) {
        return state.fingerPrint();
    }

    /**
     * Called when {@code actions[index]} could not be taken from {@code curState}. Unless
     * {@link #ENABLED_ONLY} already filtered such actions out, this records {@link #MIN_Q}
     * for that pair so the softmax assigns it (practically) zero probability next time.
     * Then it delegates to the superclass.
     */
    @Override
    protected int getNextActionAltIndex(int index, int p, Action[] actions, TLCState curState) {
        if (!ENABLED_ONLY) {
            this.q.get(actions[index]).put(this.getHash(curState), MIN_Q);
        }
        return super.getNextActionAltIndex(index, p, actions, curState);
    }

    /**
     * Samples an index into {@code actions} with probability proportional to
     * {@code exp(Q(state, action))}.
     *
     * <p>Any (state, action) pair not yet in the table is first seeded with 0.0, for every
     * action in the table. Candidates are visited in descending probability order. The first
     * one whose cumulative probability reaches {@code rng.nextDouble()} is returned, and the
     * least likely candidate is the fallback when rounding leaves the total slightly below
     * the draw.
     */
    @Override
    protected final int getNextActionIndex(RandomGenerator rng, Action[] actions, TLCState state) {
        final long s = this.getHash(state);
        this.q.values().forEach(m -> m.putIfAbsent(s, 0.0));

        // Read the Q values and their maximum.
        final double[] weights = new double[actions.length];
        double maxQ = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < weights.length; ++i) {
            weights[i] = this.q.get(actions[i]).get(s);
            maxQ = Math.max(maxQ, weights[i]);
        }

        // Numerically stable softmax: shifting by the max leaves the distribution unchanged
        // but keeps exp() from overflowing to +Infinity (large rewards) or underflowing
        // every weight to 0 (all Q values very negative). Either case would otherwise yield
        // NaN probabilities.
        double denom = 0.0;
        for (int i = 0; i < weights.length; ++i) {
            weights[i] = Math.exp(weights[i] - maxQ);
            denom += weights[i];
        }

        final ArrayList<Pair> ranked = new ArrayList<>(weights.length);
        for (int i = 0; i < weights.length; ++i) {
            ranked.add(new Pair(i, weights[i] / denom));
        }
        final double roll = rng.nextDouble();
        Collections.sort(ranked); // descending by probability

        double cumulative = 0.0;
        for (Pair candidate : ranked) {
            cumulative += candidate.key;
            if (cumulative >= roll) {
                return candidate.value;
            }
        }
        return ranked.get(ranked.size() - 1).value;
    }

    /**
     * Applies the Q-learning update backwards along the trace ending in {@code s}, once for
     * each transition from level {@code s.getLevel()} down to level 2. A (predecessor,
     * action) pair missing from the table is treated as 0.0, the same seed value that
     * {@link #getNextActionIndex} uses.
     *
     * @return always {@code true}
     */
    @Override
    protected boolean postTrace(TLCState s) {
        TLCState succ = s;
        for (int i = succ.getLevel() - 1; i > 0; --i) {
            final double maxQ = this.getMaxQ(this.getHash(succ));
            final TLCState pred = succ.getPredecessor();
            final long fp = this.getHash(pred);
            final Action ai = succ.getAction();
            final Map<Long, Double> perState = this.q.get(ai);
            final double qi = perState.getOrDefault(fp, 0.0);
            final double r = this.getReward(pred, ai, succ);
            final double updated = (1.0 - ALPHA) * qi + ALPHA * (r + GAMMA * maxQ);
            perState.put(fp, updated);
            succ = pred;
        }
        return true;
    }

    /**
     * When {@link #ENABLED_ONLY} is set, returns only those actions that have at least one
     * successor of {@code curState}, in their original order. Otherwise it delegates to the
     * superclass.
     */
    @Override
    protected Action[] filterActions(Action[] actions, TLCState curState) throws SimulationWorker.SimulationWorkerError {
        if (!ENABLED_ONLY) {
            return super.filterActions(actions, curState);
        }
        final ArrayList<Action> enabled = new ArrayList<>(actions.length);
        for (Action a : actions) {
            final StateVec next = this.tool.getNextStates(a, curState);
            if (!next.empty()) {
                enabled.add(a);
            }
        }
        return enabled.toArray(Action[]::new);
    }

    /**
     * (action index, probability) pair whose natural ordering is by descending probability
     * ({@link #key}).
     */
    private static class Pair
    implements Comparable<Pair> {
        public final double key;
        public final int value;

        public Pair(int v, double k) {
            this.key = k;
            this.value = v;
        }

        @Override
        public int compareTo(Pair o) {
            // Reversed operands yield descending order.
            return Double.compare(o.key, this.key);
        }

        @Override
        public String toString() {
            return "[key=" + this.key + ", value=" + this.value + "]";
        }
    }
}

