mirror of
https://github.com/magefree/mage.git
synced 2025-12-20 10:40:06 -08:00
refactor: improved usage of NextAction in mcts AI code (#11480)
Replaced conditional of selecting next action with runtime polymorphism, to increase the readability and easier for future changes by following Open/Close principle.
This commit is contained in:
parent
2ef9439773
commit
7913c01ec3
6 changed files with 137 additions and 63 deletions
|
|
@ -0,0 +1,21 @@
|
|||
package mage.player.ai;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
public class MCTSNextActionFactory {
|
||||
private static final HashMap<MCTSPlayer.NextAction, MCTSNodeNextAction> strategyMap = new HashMap<>();
|
||||
|
||||
static {
|
||||
strategyMap.put(MCTSPlayer.NextAction.PRIORITY, new PriorityNextAction());
|
||||
strategyMap.put(MCTSPlayer.NextAction.SELECT_BLOCKERS, new SelectBlockersNextAction());
|
||||
strategyMap.put(MCTSPlayer.NextAction.SELECT_ATTACKERS, new SelectAttackersNextAction());
|
||||
}
|
||||
|
||||
public static MCTSNodeNextAction createNextAction(MCTSPlayer.NextAction nextAction) {
|
||||
MCTSNodeNextAction strategy = strategyMap.get(nextAction);
|
||||
if (strategy == null) {
|
||||
throw new IllegalArgumentException("Unsupported action: " + nextAction);
|
||||
}
|
||||
return strategy;
|
||||
}
|
||||
}
|
||||
|
|
@ -127,64 +127,7 @@ public class MCTSNode {
|
|||
if (player.getNextAction() == null) {
|
||||
logger.fatal("next action is null");
|
||||
}
|
||||
switch (player.getNextAction()) {
|
||||
case PRIORITY:
|
||||
// logger.info("Priority for player:" + player.getName() + " turn: " + game.getTurnNum() + " phase: " + game.getTurnPhaseType() + " step: " + game.getTurnStepType());
|
||||
List<Ability> abilities;
|
||||
if (!USE_ACTION_CACHE)
|
||||
abilities = player.getPlayableOptions(game);
|
||||
else
|
||||
abilities = getPlayables(player, fullStateValue, game);
|
||||
for (Ability ability: abilities) {
|
||||
Game sim = game.copy();
|
||||
// logger.info("expand " + ability.toString());
|
||||
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
|
||||
simPlayer.activateAbility((ActivatedAbility)ability, sim);
|
||||
sim.resume();
|
||||
children.add(new MCTSNode(this, sim, ability));
|
||||
}
|
||||
break;
|
||||
case SELECT_ATTACKERS:
|
||||
// logger.info("Select attackers:" + player.getName());
|
||||
List<List<UUID>> attacks;
|
||||
if (!USE_ACTION_CACHE)
|
||||
attacks = player.getAttacks(game);
|
||||
else
|
||||
attacks = getAttacks(player, fullStateValue, game);
|
||||
UUID defenderId = game.getOpponents(player.getId()).iterator().next();
|
||||
for (List<UUID> attack: attacks) {
|
||||
Game sim = game.copy();
|
||||
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
|
||||
for (UUID attackerId: attack) {
|
||||
simPlayer.declareAttacker(attackerId, defenderId, sim, false);
|
||||
}
|
||||
sim.resume();
|
||||
children.add(new MCTSNode(this, sim, sim.getCombat()));
|
||||
}
|
||||
break;
|
||||
case SELECT_BLOCKERS:
|
||||
// logger.info("Select blockers:" + player.getName());
|
||||
List<List<List<UUID>>> blocks;
|
||||
if (!USE_ACTION_CACHE)
|
||||
blocks = player.getBlocks(game);
|
||||
else
|
||||
blocks = getBlocks(player, fullStateValue, game);
|
||||
for (List<List<UUID>> block: blocks) {
|
||||
Game sim = game.copy();
|
||||
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
|
||||
List<CombatGroup> groups = sim.getCombat().getGroups();
|
||||
for (int i = 0; i < groups.size(); i++) {
|
||||
if (i < block.size()) {
|
||||
for (UUID blockerId: block.get(i)) {
|
||||
simPlayer.declareBlocker(simPlayer.getId(), blockerId, groups.get(i).getAttackers().get(0), sim);
|
||||
}
|
||||
}
|
||||
}
|
||||
sim.resume();
|
||||
children.add(new MCTSNode(this, sim, sim.getCombat()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
children.addAll(MCTSNextActionFactory.createNextAction(player.getNextAction()).performNextAction(this, player, game, fullStateValue));
|
||||
game = null;
|
||||
}
|
||||
|
||||
|
|
@ -462,8 +405,8 @@ public class MCTSNode {
|
|||
private static long attacksMiss = 0;
|
||||
private static long blocksHit = 0;
|
||||
private static long blocksMiss = 0;
|
||||
|
||||
private static List<Ability> getPlayables(MCTSPlayer player, String state, Game game) {
|
||||
|
||||
protected static List<Ability> getPlayables(MCTSPlayer player, String state, Game game) {
|
||||
if (playablesCache.containsKey(state)) {
|
||||
playablesHit++;
|
||||
return playablesCache.get(state);
|
||||
|
|
@ -475,8 +418,8 @@ public class MCTSNode {
|
|||
return abilities;
|
||||
}
|
||||
}
|
||||
|
||||
private static List<List<UUID>> getAttacks(MCTSPlayer player, String state, Game game) {
|
||||
|
||||
protected static List<List<UUID>> getAttacks(MCTSPlayer player, String state, Game game) {
|
||||
if (attacksCache.containsKey(state)) {
|
||||
attacksHit++;
|
||||
return attacksCache.get(state);
|
||||
|
|
@ -489,7 +432,7 @@ public class MCTSNode {
|
|||
}
|
||||
}
|
||||
|
||||
private static List<List<List<UUID>>> getBlocks(MCTSPlayer player, String state, Game game) {
|
||||
protected static List<List<List<UUID>>> getBlocks(MCTSPlayer player, String state, Game game) {
|
||||
if (blocksCache.containsKey(state)) {
|
||||
blocksHit++;
|
||||
return blocksCache.get(state);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
package mage.player.ai;
|
||||
|
||||
import mage.game.Game;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface MCTSNodeNextAction {
|
||||
List<MCTSNode> performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue);
|
||||
}
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
package mage.player.ai;
|
||||
|
||||
import mage.abilities.Ability;
|
||||
import mage.abilities.ActivatedAbility;
|
||||
import mage.game.Game;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class PriorityNextAction implements MCTSNodeNextAction{
|
||||
|
||||
@Override
|
||||
public List<MCTSNode> performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue) {
|
||||
List<MCTSNode> children = new ArrayList<>();
|
||||
List<Ability> abilities;
|
||||
if (!MCTSNode.USE_ACTION_CACHE)
|
||||
abilities = player.getPlayableOptions(game);
|
||||
else
|
||||
abilities = MCTSNode.getPlayables(player, fullStateValue, game);
|
||||
for (Ability ability: abilities) {
|
||||
Game sim = game.copy();
|
||||
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
|
||||
simPlayer.activateAbility((ActivatedAbility)ability, sim);
|
||||
sim.resume();
|
||||
children.add(new MCTSNode(node, sim, ability));
|
||||
}
|
||||
|
||||
return children;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
package mage.player.ai;
|
||||
|
||||
import mage.game.Game;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static mage.player.ai.MCTSNode.getAttacks;
|
||||
|
||||
public class SelectAttackersNextAction implements MCTSNodeNextAction{
|
||||
@Override
|
||||
public List<MCTSNode> performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue) {
|
||||
List<MCTSNode> children = new ArrayList<>();
|
||||
List<List<UUID>> attacks;
|
||||
if (!MCTSNode.USE_ACTION_CACHE)
|
||||
attacks = player.getAttacks(game);
|
||||
else
|
||||
attacks = getAttacks(player, fullStateValue, game);
|
||||
UUID defenderId = game.getOpponents(player.getId()).iterator().next();
|
||||
for (List<UUID> attack: attacks) {
|
||||
Game sim = game.copy();
|
||||
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
|
||||
for (UUID attackerId: attack) {
|
||||
simPlayer.declareAttacker(attackerId, defenderId, sim, false);
|
||||
}
|
||||
sim.resume();
|
||||
children.add(new MCTSNode(node, sim, sim.getCombat()));
|
||||
}
|
||||
|
||||
return children;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
package mage.player.ai;
|
||||
|
||||
import mage.game.Game;
|
||||
import mage.game.combat.CombatGroup;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static mage.player.ai.MCTSNode.getBlocks;
|
||||
|
||||
public class SelectBlockersNextAction implements MCTSNodeNextAction{
|
||||
@Override
|
||||
public List<MCTSNode> performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue) {
|
||||
List<MCTSNode> children = new ArrayList<>();
|
||||
List<List<List<UUID>>> blocks;
|
||||
if (!MCTSNode.USE_ACTION_CACHE)
|
||||
blocks = player.getBlocks(game);
|
||||
else
|
||||
blocks = getBlocks(player, fullStateValue, game);
|
||||
for (List<List<UUID>> block : blocks) {
|
||||
Game sim = game.copy();
|
||||
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
|
||||
List<CombatGroup> groups = sim.getCombat().getGroups();
|
||||
for (int i = 0; i < groups.size(); i++) {
|
||||
if (i < block.size()) {
|
||||
for (UUID blockerId : block.get(i)) {
|
||||
simPlayer.declareBlocker(simPlayer.getId(), blockerId, groups.get(i).getAttackers().get(0), sim);
|
||||
}
|
||||
}
|
||||
}
|
||||
sim.resume();
|
||||
children.add(new MCTSNode(node, sim, sim.getCombat()));
|
||||
}
|
||||
|
||||
return children;
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue