diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java index f2599117912..1f7a71bdb98 100644 --- a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java @@ -27,15 +27,21 @@ */ package mage.player.ai; +import java.util.ArrayList; import java.util.List; import java.util.UUID; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.logging.Level; import mage.Constants.RangeOfInfluence; import mage.abilities.Ability; import mage.abilities.ActivatedAbility; import mage.game.Game; import mage.game.combat.Combat; import mage.game.combat.CombatGroup; -import mage.game.permanent.Permanent; +import mage.player.ai.MCTSPlayer.NextAction; import mage.players.Player; import org.apache.log4j.Logger; @@ -48,11 +54,15 @@ public class ComputerPlayerMCTS extends ComputerPlayer imple protected transient MCTSNode root; protected int thinkTime; private final static transient Logger logger = Logger.getLogger(ComputerPlayerMCTS.class); + private ExecutorService pool; + private int cores; public ComputerPlayerMCTS(String name, RangeOfInfluence range, int skill) { super(name, range); human = false; thinkTime = skill; + cores = Runtime.getRuntime().availableProcessors(); + pool = Executors.newFixedThreadPool(cores); } protected ComputerPlayerMCTS(UUID id) { @@ -70,32 +80,32 @@ public class ComputerPlayerMCTS extends ComputerPlayer imple @Override public void priority(Game game) { - getNextAction(game); + getNextAction(game, NextAction.PRIORITY); Ability ability = root.getAction(); if (ability == null) logger.fatal("null ability"); activateAbility((ActivatedAbility)ability, game); } - protected void calculateActions(Game game) { + protected void calculateActions(Game game, NextAction action) { if (root == null) { Game sim = createMCTSGame(game); MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId); - player.setNextAction(MCTSPlayer.NextAction.PRIORITY); + player.setNextAction(action); root = new MCTSNode(sim); } - applyMCTS(); + applyMCTS(game, action); root = root.bestChild(); root.emancipate(); } - protected void getNextAction(Game game) { + protected void getNextAction(Game game, NextAction nextAction) { if (root != null) { - root = root.getMatchingState(game.getState().getValue().hashCode()); + root = root.getMatchingState(game.getState().getValue().hashCode(), nextAction); if (root != null) root.emancipate(); } - calculateActions(game); + calculateActions(game, nextAction); } // @Override @@ -171,10 +181,11 @@ public class ComputerPlayerMCTS extends ComputerPlayer imple @Override public void selectAttackers(Game game) { Game sim = createMCTSGame(game); - MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId); - player.setNextAction(MCTSPlayer.NextAction.SELECT_ATTACKERS); - root = new MCTSNode(sim); - applyMCTS(); + getNextAction(sim, NextAction.SELECT_ATTACKERS); +// MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId); +// player.setNextAction(MCTSPlayer.NextAction.SELECT_ATTACKERS); +// root = new MCTSNode(sim); +// applyMCTS(); Combat combat = root.bestChild().getCombat(); UUID opponentId = game.getCombat().getDefenders().iterator().next(); for (UUID attackerId: combat.getAttackers()) { @@ -185,10 +196,11 @@ public class ComputerPlayerMCTS extends ComputerPlayer imple @Override public void selectBlockers(Game game) { Game sim = createMCTSGame(game); - MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId); - player.setNextAction(MCTSPlayer.NextAction.SELECT_BLOCKERS); - root = new MCTSNode(sim); - applyMCTS(); + getNextAction(sim, NextAction.SELECT_BLOCKERS); +// MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId); +// player.setNextAction(MCTSPlayer.NextAction.SELECT_BLOCKERS); +// root = new MCTSNode(sim); +// applyMCTS(); Combat combat = root.bestChild().getCombat(); List groups = game.getCombat().getGroups(); for (int i = 0; i < groups.size(); i++) { @@ -235,51 +247,34 @@ public class ComputerPlayerMCTS extends ComputerPlayer imple // throw new UnsupportedOperationException("Not supported yet."); // } - protected void applyMCTS() { + protected void applyMCTS(final Game game, final NextAction action) { long startTime = System.nanoTime(); long endTime = startTime + (thinkTime * 1000000000l); - MCTSNode current; - - if (root.getNumChildren() == 1) - //there is only one possible action - return; - logger.info("applyMCTS - Thinking for " + (endTime - startTime)/1000000000.0 + "s"); - while (true) { - long currentTime = System.nanoTime(); - logger.info("Remaining time: " + (endTime - currentTime)/1000000000.0 + "s"); - if (currentTime > endTime) - break; - current = root; - - // Selection - while (!current.isLeaf()) { - current = current.select(); - } - - int result; - if (!current.isTerminal()) { - // Expansion - current.expand(); - - if (current == root && current.getNumChildren() == 1) - //there is only one possible action - return; - - // Simulation - current = current.select(); - result = current.simulate(this.playerId); - } - else { - result = current.isWinner(this.playerId)?1:0; - } - // Backpropagation - current.backpropagate(result); + + List tasks = new ArrayList(); + for (int i = 0; i < cores; i++) { + Game sim = createMCTSGame(game); + MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId); + player.setNextAction(action); + MCTSExecutor exec = new MCTSExecutor(sim, playerId, thinkTime); + tasks.add(exec); } + + try { + pool.invokeAll(tasks); + } catch (InterruptedException ex) { + logger.warn("applyMCTS interrupted"); + } + + for (MCTSExecutor task: tasks) { + root.merge(task.getRoot()); + } + logger.info("Created " + root.getNodeCount() + " nodes"); return; } - + /** * Copies game and replaces all players in copy with mcts players * Shuffles each players library so that there is no knowledge of its order diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java index 5788c684caa..693c5bece98 100644 --- a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java @@ -38,6 +38,7 @@ import mage.game.GameState; import mage.game.combat.Combat; import mage.game.combat.CombatGroup; import mage.game.turn.Step.StepPart; +import mage.player.ai.MCTSPlayer.NextAction; import mage.players.Player; import org.apache.log4j.Logger; @@ -122,12 +123,12 @@ public class MCTSNode { } switch (player.getNextAction()) { case PRIORITY: - logger.info("Priority for player:" + player.getName() + " turn: " + game.getTurnNum() + " phase: " + game.getPhase().getType() + " step: " + game.getStep().getType()); +// logger.info("Priority for player:" + player.getName() + " turn: " + game.getTurnNum() + " phase: " + game.getPhase().getType() + " step: " + game.getStep().getType()); List abilities = player.getPlayableOptions(game); for (Ability ability: abilities) { Game sim = game.copy(); int simState = sim.getState().getValue().hashCode(); - logger.info("expand " + ability.toString()); +// logger.info("expand " + ability.toString()); MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId()); simPlayer.activateAbility((ActivatedAbility)ability, sim); sim.resume(); @@ -135,7 +136,7 @@ public class MCTSNode { } break; case SELECT_ATTACKERS: - logger.info("Select attackers:" + player.getName()); +// logger.info("Select attackers:" + player.getName()); List> attacks = player.getAttacks(game); UUID defenderId = game.getOpponents(player.getId()).iterator().next(); for (List attack: attacks) { @@ -150,7 +151,7 @@ public class MCTSNode { } break; case SELECT_BLOCKERS: - logger.info("Select blockers:" + player.getName()); +// logger.info("Select blockers:" + player.getName()); List>> blocks = player.getBlocks(game); for (List> block: blocks) { Game sim = game.copy(); @@ -178,9 +179,9 @@ public class MCTSNode { long duration = System.nanoTime() - startTime; int retVal = 0; for (Player simPlayer: sim.getPlayers().values()) { - logger.info(simPlayer.getName() + " calculated " + ((SimulatedPlayerMCTS)simPlayer).getActionCount() + " actions in " + duration/1000000000.0 + "s"); +// logger.info(simPlayer.getName() + " calculated " + ((SimulatedPlayerMCTS)simPlayer).getActionCount() + " actions in " + duration/1000000000.0 + "s"); if (simPlayer.getId().equals(playerId) && simPlayer.hasWon()) { - logger.info("AI won the simulation"); +// logger.info("AI won the simulation"); retVal = 1; } } @@ -267,18 +268,49 @@ public class MCTSNode { return false; } - public MCTSNode getMatchingState(int state) { + public MCTSNode getMatchingState(int state, NextAction nextAction) { for (MCTSNode node: children) { -// logger.info(state); -// logger.info(node.stateValue); if (node.stateValue == state && node.action != null) { - return node; + MCTSPlayer player; + if (game.getStep().getStepPart() == StepPart.PRIORITY) + player = (MCTSPlayer) game.getPlayer(game.getPriorityPlayerId()); + else { + if (game.getStep().getType() == PhaseStep.DECLARE_BLOCKERS) + player = (MCTSPlayer) game.getPlayer(game.getCombat().getDefenders().iterator().next()); + else + player = (MCTSPlayer) game.getPlayer(game.getActivePlayerId()); + } + if (player.getNextAction() == nextAction) + return node; } - MCTSNode match = node.getMatchingState(state); + MCTSNode match = node.getMatchingState(state, nextAction); if (match != null) return node; } return null; } + public void merge(MCTSNode merge) { + this.visits += merge.visits; + this.wins += merge.wins; + + List mergeChildren = new ArrayList(); + for (MCTSNode child: merge.children) { + mergeChildren.add(child); + } + + for (MCTSNode child: children) { + for (MCTSNode mergeChild: mergeChildren) { + if (mergeChild.stateValue == child.stateValue) { + child.merge(mergeChild); + mergeChildren.remove(mergeChild); + break; + } + } + } + if (!mergeChildren.isEmpty()) { + children.addAll(mergeChildren); + } + } + } diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/SimulatedPlayerMCTS.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/SimulatedPlayerMCTS.java index 0d565b0ced2..0be830f21a9 100644 --- a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/SimulatedPlayerMCTS.java +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/SimulatedPlayerMCTS.java @@ -265,7 +265,7 @@ public class SimulatedPlayerMCTS extends MCTSPlayer { public boolean choose(Outcome outcome, Cards cards, TargetCard target, Game game) { if (cards.isEmpty()) return !target.isRequired(); - Set possibleTargets = target.possibleTargets(playerId, game); + Set possibleTargets = target.possibleTargets(playerId, cards, game); if (possibleTargets.isEmpty()) return !target.isRequired(); Iterator it = possibleTargets.iterator();