diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNextActionFactory.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNextActionFactory.java new file mode 100644 index 00000000000..fb408fafa34 --- /dev/null +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNextActionFactory.java @@ -0,0 +1,21 @@ +package mage.player.ai; + +import java.util.HashMap; + +public class MCTSNextActionFactory { + private static final HashMap strategyMap = new HashMap<>(); + + static { + strategyMap.put(MCTSPlayer.NextAction.PRIORITY, new PriorityNextAction()); + strategyMap.put(MCTSPlayer.NextAction.SELECT_BLOCKERS, new SelectBlockersNextAction()); + strategyMap.put(MCTSPlayer.NextAction.SELECT_ATTACKERS, new SelectAttackersNextAction()); + } + + public static MCTSNodeNextAction createNextAction(MCTSPlayer.NextAction nextAction) { + MCTSNodeNextAction strategy = strategyMap.get(nextAction); + if (strategy == null) { + throw new IllegalArgumentException("Unsupported action: " + nextAction); + } + return strategy; + } +} diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java index 467ca4ea562..463a6cce7c9 100644 --- a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java @@ -127,64 +127,7 @@ public class MCTSNode { if (player.getNextAction() == null) { logger.fatal("next action is null"); } - switch (player.getNextAction()) { - case PRIORITY: -// logger.info("Priority for player:" + player.getName() + " turn: " + game.getTurnNum() + " phase: " + game.getTurnPhaseType() + " step: " + game.getTurnStepType()); - List abilities; - if (!USE_ACTION_CACHE) - abilities = player.getPlayableOptions(game); - else - abilities = getPlayables(player, fullStateValue, game); - for (Ability ability: abilities) { - Game sim = game.copy(); -// logger.info("expand " + ability.toString()); - MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId()); - simPlayer.activateAbility((ActivatedAbility)ability, sim); - sim.resume(); - children.add(new MCTSNode(this, sim, ability)); - } - break; - case SELECT_ATTACKERS: -// logger.info("Select attackers:" + player.getName()); - List> attacks; - if (!USE_ACTION_CACHE) - attacks = player.getAttacks(game); - else - attacks = getAttacks(player, fullStateValue, game); - UUID defenderId = game.getOpponents(player.getId()).iterator().next(); - for (List attack: attacks) { - Game sim = game.copy(); - MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId()); - for (UUID attackerId: attack) { - simPlayer.declareAttacker(attackerId, defenderId, sim, false); - } - sim.resume(); - children.add(new MCTSNode(this, sim, sim.getCombat())); - } - break; - case SELECT_BLOCKERS: -// logger.info("Select blockers:" + player.getName()); - List>> blocks; - if (!USE_ACTION_CACHE) - blocks = player.getBlocks(game); - else - blocks = getBlocks(player, fullStateValue, game); - for (List> block: blocks) { - Game sim = game.copy(); - MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId()); - List groups = sim.getCombat().getGroups(); - for (int i = 0; i < groups.size(); i++) { - if (i < block.size()) { - for (UUID blockerId: block.get(i)) { - simPlayer.declareBlocker(simPlayer.getId(), blockerId, groups.get(i).getAttackers().get(0), sim); - } - } - } - sim.resume(); - children.add(new MCTSNode(this, sim, sim.getCombat())); - } - break; - } + children.addAll(MCTSNextActionFactory.createNextAction(player.getNextAction()).performNextAction(this, player, game, fullStateValue)); game = null; } @@ -462,8 +405,8 @@ public class MCTSNode { private static long attacksMiss = 0; private static long blocksHit = 0; private static long blocksMiss = 0; - - private static List getPlayables(MCTSPlayer player, String state, Game game) { + + protected static List getPlayables(MCTSPlayer player, String state, Game game) { if (playablesCache.containsKey(state)) { playablesHit++; return playablesCache.get(state); @@ -475,8 +418,8 @@ public class MCTSNode { return abilities; } } - - private static List> getAttacks(MCTSPlayer player, String state, Game game) { + + protected static List> getAttacks(MCTSPlayer player, String state, Game game) { if (attacksCache.containsKey(state)) { attacksHit++; return attacksCache.get(state); @@ -489,7 +432,7 @@ public class MCTSNode { } } - private static List>> getBlocks(MCTSPlayer player, String state, Game game) { + protected static List>> getBlocks(MCTSPlayer player, String state, Game game) { if (blocksCache.containsKey(state)) { blocksHit++; return blocksCache.get(state); diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNodeNextAction.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNodeNextAction.java new file mode 100644 index 00000000000..befa16c515b --- /dev/null +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNodeNextAction.java @@ -0,0 +1,9 @@ +package mage.player.ai; + +import mage.game.Game; + +import java.util.List; + +public interface MCTSNodeNextAction { + List performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue); +} diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/PriorityNextAction.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/PriorityNextAction.java new file mode 100644 index 00000000000..ac38b8969e3 --- /dev/null +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/PriorityNextAction.java @@ -0,0 +1,30 @@ +package mage.player.ai; + +import mage.abilities.Ability; +import mage.abilities.ActivatedAbility; +import mage.game.Game; + +import java.util.ArrayList; +import java.util.List; + +public class PriorityNextAction implements MCTSNodeNextAction{ + + @Override + public List performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue) { + List children = new ArrayList<>(); + List abilities; + if (!MCTSNode.USE_ACTION_CACHE) + abilities = player.getPlayableOptions(game); + else + abilities = MCTSNode.getPlayables(player, fullStateValue, game); + for (Ability ability: abilities) { + Game sim = game.copy(); + MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId()); + simPlayer.activateAbility((ActivatedAbility)ability, sim); + sim.resume(); + children.add(new MCTSNode(node, sim, ability)); + } + + return children; + } +} diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/SelectAttackersNextAction.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/SelectAttackersNextAction.java new file mode 100644 index 00000000000..3627dd74d92 --- /dev/null +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/SelectAttackersNextAction.java @@ -0,0 +1,33 @@ +package mage.player.ai; + +import mage.game.Game; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + +import static mage.player.ai.MCTSNode.getAttacks; + +public class SelectAttackersNextAction implements MCTSNodeNextAction{ + @Override + public List performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue) { + List children = new ArrayList<>(); + List> attacks; + if (!MCTSNode.USE_ACTION_CACHE) + attacks = player.getAttacks(game); + else + attacks = getAttacks(player, fullStateValue, game); + UUID defenderId = game.getOpponents(player.getId()).iterator().next(); + for (List attack: attacks) { + Game sim = game.copy(); + MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId()); + for (UUID attackerId: attack) { + simPlayer.declareAttacker(attackerId, defenderId, sim, false); + } + sim.resume(); + children.add(new MCTSNode(node, sim, sim.getCombat())); + } + + return children; + } +} diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/SelectBlockersNextAction.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/SelectBlockersNextAction.java new file mode 100644 index 00000000000..ec032fa32bf --- /dev/null +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/SelectBlockersNextAction.java @@ -0,0 +1,38 @@ +package mage.player.ai; + +import mage.game.Game; +import mage.game.combat.CombatGroup; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + +import static mage.player.ai.MCTSNode.getBlocks; + +public class SelectBlockersNextAction implements MCTSNodeNextAction{ + @Override + public List performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue) { + List children = new ArrayList<>(); + List>> blocks; + if (!MCTSNode.USE_ACTION_CACHE) + blocks = player.getBlocks(game); + else + blocks = getBlocks(player, fullStateValue, game); + for (List> block : blocks) { + Game sim = game.copy(); + MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId()); + List groups = sim.getCombat().getGroups(); + for (int i = 0; i < groups.size(); i++) { + if (i < block.size()) { + for (UUID blockerId : block.get(i)) { + simPlayer.declareBlocker(simPlayer.getId(), blockerId, groups.get(i).getAttackers().get(0), sim); + } + } + } + sim.resume(); + children.add(new MCTSNode(node, sim, sim.getCombat())); + } + + return children; + } +}