diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java index 8462270e31e..edc2f617a7f 100644 --- a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java @@ -52,7 +52,8 @@ import org.apache.log4j.Logger; */ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { - private static final int thinkTimeRatioThreshold = 20; + private static final int THINK_MIN_RATIO = 20; + private static final int THINK_MAX_RATIO = 100; protected transient MCTSNode root; protected int maxThinkTime; @@ -112,8 +113,10 @@ public class ComputerPlayerMCTS extends ComputerPlayer imple if (root != null) { MCTSNode newRoot = null; newRoot = root.getMatchingState(game.getState().getValue(false, game)); - if (newRoot != null) + if (newRoot != null) { newRoot.emancipate(); + logger.info("choose action:" + newRoot.getAction() + " success ratio: " + newRoot.getWinRatio()); + } else logger.info("unable to find matching state"); root = newRoot; @@ -258,28 +261,58 @@ public class ComputerPlayerMCTS extends ComputerPlayer imple logger.info("applyMCTS - Thinking for " + (endTime - startTime)/1000000000.0 + "s"); if (thinkTime > 0) { - List tasks = new ArrayList(); - for (int i = 0; i < cores; i++) { - Game sim = createMCTSGame(game); - MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId); - player.setNextAction(action); - MCTSExecutor exec = new MCTSExecutor(sim, playerId, thinkTime); - tasks.add(exec); +// List tasks = new ArrayList(); +// for (int i = 0; i < cores; i++) { +// Game sim = createMCTSGame(game); +// MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId); +// player.setNextAction(action); +// MCTSExecutor exec = new MCTSExecutor(sim, playerId, thinkTime); +// tasks.add(exec); +// } +// +// try { +// pool.invokeAll(tasks); +// } catch (InterruptedException ex) { +// logger.warn("applyMCTS interrupted"); +// } +// +// for (MCTSExecutor task: tasks) { +// root.merge(task.getRoot()); +// task.clear(); +// } +// tasks.clear(); + + MCTSNode current; + int simCount = 0; + while (true) { + long currentTime = System.nanoTime(); + if (currentTime > endTime) + break; + current = root; + + // Selection + while (!current.isLeaf()) { + current = current.select(this.playerId); + } + + int result; + if (!current.isTerminal()) { + // Expansion + current.expand(); + + // Simulation + current = current.select(this.playerId); + result = current.simulate(this.playerId); + simCount++; + } + else { + result = current.isWinner(this.playerId)?1:-1; + } + // Backpropagation + current.backpropagate(result); } - try { - pool.invokeAll(tasks); - } catch (InterruptedException ex) { - logger.warn("applyMCTS interrupted"); - } - - for (MCTSExecutor task: tasks) { - root.merge(task.getRoot()); - task.clear(); - } - tasks.clear(); - - logger.info("Created " + root.getNodeCount() + " nodes - size: " + root.size()); + logger.info("Simulated " + simCount + " games - nodes in tree: " + root.size()); displayMemory(); } @@ -287,32 +320,38 @@ public class ComputerPlayerMCTS extends ComputerPlayer imple return; } - //try to ensure that there are at least 20 simulations per node at all times + //try to ensure that there are at least THINK_MIN_RATIO simulations per node at all times private int calculateThinkTime(Game game, NextAction action) { int thinkTime = 0; int nodeSizeRatio = 0; if (root.getNumChildren() > 0) - nodeSizeRatio = root.size() / root.getNumChildren(); + nodeSizeRatio = root.getVisits() / root.getNumChildren(); logger.info("Ratio: " + nodeSizeRatio); PhaseStep curStep = game.getStep().getType(); if (action == NextAction.SELECT_ATTACKERS || action == NextAction.SELECT_BLOCKERS) { - if (nodeSizeRatio < thinkTimeRatioThreshold) { + if (nodeSizeRatio < THINK_MIN_RATIO) { thinkTime = maxThinkTime; } + else if (nodeSizeRatio >= THINK_MAX_RATIO) { + thinkTime = 0; + } else { thinkTime = maxThinkTime / 2; } } - else if (game.getActivePlayerId().equals(playerId) && (curStep == PhaseStep.PRECOMBAT_MAIN || curStep == PhaseStep.POSTCOMBAT_MAIN)) { - if (nodeSizeRatio < thinkTimeRatioThreshold) { + else if (game.getActivePlayerId().equals(playerId) && (curStep == PhaseStep.PRECOMBAT_MAIN || curStep == PhaseStep.POSTCOMBAT_MAIN) && game.getStack().isEmpty()) { + if (nodeSizeRatio < THINK_MIN_RATIO) { thinkTime = maxThinkTime; } + else if (nodeSizeRatio >= THINK_MAX_RATIO) { + thinkTime = 0; + } else { thinkTime = maxThinkTime / 2; } } else { - if (nodeSizeRatio < thinkTimeRatioThreshold) { + if (nodeSizeRatio < THINK_MIN_RATIO) { thinkTime = maxThinkTime / 2; } else { diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java index 578c18176e0..d339e3acb80 100644 --- a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java @@ -61,7 +61,6 @@ public class MCTSNode { private MCTSNode parent; private List children = new ArrayList(); private Ability action; -// private Combat combat; private Game game; private String stateValue; private UUID playerId; @@ -88,7 +87,6 @@ public class MCTSNode { this.game = game; this.stateValue = game.getState().getValue(false, game); this.parent = parent; -// this.combat = game.getCombat(); setPlayer(); nodeCount++; } @@ -140,7 +138,6 @@ public class MCTSNode { List abilities = player.getPlayableOptions(game); for (Ability ability: abilities) { Game sim = game.copy(); -// String simState = sim.getState().getValue(false, sim); // logger.info("expand " + ability.toString()); MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId()); simPlayer.activateAbility((ActivatedAbility)ability, sim); @@ -154,7 +151,6 @@ public class MCTSNode { UUID defenderId = game.getOpponents(player.getId()).iterator().next(); for (List attack: attacks) { Game sim = game.copy(); -// String simState = sim.getState().getValue(false, sim); MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId()); for (UUID attackerId: attack) { simPlayer.declareAttacker(attackerId, defenderId, sim); @@ -168,7 +164,6 @@ public class MCTSNode { List>> blocks = player.getBlocks(game); for (List> block: blocks) { Game sim = game.copy(); -// String simState = sim.getState().getValue(false, sim); MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId()); List groups = sim.getCombat().getGroups(); for (int i = 0; i < groups.size(); i++) { @@ -248,7 +243,10 @@ public class MCTSNode { } public void emancipate() { - this.parent = null; + if (parent != null) { + this.parent.children.remove(this); + this.parent = null; + } } public Ability getAction() { @@ -275,6 +273,16 @@ public class MCTSNode { return stateValue; } + public double getWinRatio() { + if (visits > 0) + return wins/(visits * 1.0); + return -1.0; + } + + public int getVisits() { + return visits; + } + /** * Copies game and replaces all players in copy with simulated players * Shuffles each players library so that there is no knowledge of its order @@ -289,26 +297,37 @@ public class MCTSNode { Player origPlayer = game.getState().getPlayers().get(copyPlayer.getId()).copy(); SimulatedPlayerMCTS newPlayer = new SimulatedPlayerMCTS(copyPlayer.getId(), true); newPlayer.restore(origPlayer); - if (!newPlayer.getId().equals(playerId)) { - int handSize = newPlayer.getHand().size(); - newPlayer.getLibrary().addAll(newPlayer.getHand().getCards(sim), sim); - newPlayer.getHand().clear(); - newPlayer.getLibrary().shuffle(); - for (int i = 0; i < handSize; i++) { - Card card = newPlayer.getLibrary().removeFromTop(sim); - sim.setZone(card.getId(), Zone.HAND); - newPlayer.getHand().add(card); - } - } - else { - newPlayer.getLibrary().shuffle(); - } sim.getState().getPlayers().put(copyPlayer.getId(), newPlayer); } + randomizePlayers(sim, playerId); sim.setSimulation(true); return sim; } + /* + * Shuffles each players library so that there is no knowledge of its order + * Swaps all other players hands with random cards from the library so that + * there is no knowledge of what cards are in opponents hands + */ + protected void randomizePlayers(Game game, UUID playerId) { + for (Player player: game.getState().getPlayers().values()) { + if (!player.getId().equals(playerId)) { + int handSize = player.getHand().size(); + player.getLibrary().addAll(player.getHand().getCards(game), game); + player.getHand().clear(); + player.getLibrary().shuffle(); + for (int i = 0; i < handSize; i++) { + Card card = player.getLibrary().removeFromTop(game); + game.setZone(card.getId(), Zone.HAND); + player.getHand().add(card); + } + } + else { + player.getLibrary().shuffle(); + } + } + } + public boolean isTerminal() { return game.isGameOver(); } diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSPlayer.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSPlayer.java index 24e482a8166..0760ba32440 100644 --- a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSPlayer.java +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSPlayer.java @@ -28,16 +28,13 @@ package mage.player.ai; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.UUID; import mage.abilities.Ability; import mage.abilities.SpellAbility; import mage.abilities.common.PassAbility; import mage.abilities.costs.mana.GenericManaCost; import mage.game.Game; -import mage.game.combat.Combat; import mage.game.permanent.Permanent; import org.apache.log4j.Logger; diff --git a/Mage.Server/plugins/mage-player-aimcts.jar b/Mage.Server/plugins/mage-player-aimcts.jar index 65749101e25..b6ab68d2a9b 100644 Binary files a/Mage.Server/plugins/mage-player-aimcts.jar and b/Mage.Server/plugins/mage-player-aimcts.jar differ