mirror of
https://github.com/magefree/mage.git
synced 2025-12-20 02:30:08 -08:00
refactor: improved usage of NextAction in mcts AI code (#11480)
Replaced conditional of selecting next action with runtime polymorphism, to increase the readability and easier for future changes by following Open/Close principle.
This commit is contained in:
parent
2ef9439773
commit
7913c01ec3
6 changed files with 137 additions and 63 deletions
|
|
@ -0,0 +1,21 @@
|
||||||
|
package mage.player.ai;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
|
||||||
|
public class MCTSNextActionFactory {
|
||||||
|
private static final HashMap<MCTSPlayer.NextAction, MCTSNodeNextAction> strategyMap = new HashMap<>();
|
||||||
|
|
||||||
|
static {
|
||||||
|
strategyMap.put(MCTSPlayer.NextAction.PRIORITY, new PriorityNextAction());
|
||||||
|
strategyMap.put(MCTSPlayer.NextAction.SELECT_BLOCKERS, new SelectBlockersNextAction());
|
||||||
|
strategyMap.put(MCTSPlayer.NextAction.SELECT_ATTACKERS, new SelectAttackersNextAction());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static MCTSNodeNextAction createNextAction(MCTSPlayer.NextAction nextAction) {
|
||||||
|
MCTSNodeNextAction strategy = strategyMap.get(nextAction);
|
||||||
|
if (strategy == null) {
|
||||||
|
throw new IllegalArgumentException("Unsupported action: " + nextAction);
|
||||||
|
}
|
||||||
|
return strategy;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -127,64 +127,7 @@ public class MCTSNode {
|
||||||
if (player.getNextAction() == null) {
|
if (player.getNextAction() == null) {
|
||||||
logger.fatal("next action is null");
|
logger.fatal("next action is null");
|
||||||
}
|
}
|
||||||
switch (player.getNextAction()) {
|
children.addAll(MCTSNextActionFactory.createNextAction(player.getNextAction()).performNextAction(this, player, game, fullStateValue));
|
||||||
case PRIORITY:
|
|
||||||
// logger.info("Priority for player:" + player.getName() + " turn: " + game.getTurnNum() + " phase: " + game.getTurnPhaseType() + " step: " + game.getTurnStepType());
|
|
||||||
List<Ability> abilities;
|
|
||||||
if (!USE_ACTION_CACHE)
|
|
||||||
abilities = player.getPlayableOptions(game);
|
|
||||||
else
|
|
||||||
abilities = getPlayables(player, fullStateValue, game);
|
|
||||||
for (Ability ability: abilities) {
|
|
||||||
Game sim = game.copy();
|
|
||||||
// logger.info("expand " + ability.toString());
|
|
||||||
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
|
|
||||||
simPlayer.activateAbility((ActivatedAbility)ability, sim);
|
|
||||||
sim.resume();
|
|
||||||
children.add(new MCTSNode(this, sim, ability));
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case SELECT_ATTACKERS:
|
|
||||||
// logger.info("Select attackers:" + player.getName());
|
|
||||||
List<List<UUID>> attacks;
|
|
||||||
if (!USE_ACTION_CACHE)
|
|
||||||
attacks = player.getAttacks(game);
|
|
||||||
else
|
|
||||||
attacks = getAttacks(player, fullStateValue, game);
|
|
||||||
UUID defenderId = game.getOpponents(player.getId()).iterator().next();
|
|
||||||
for (List<UUID> attack: attacks) {
|
|
||||||
Game sim = game.copy();
|
|
||||||
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
|
|
||||||
for (UUID attackerId: attack) {
|
|
||||||
simPlayer.declareAttacker(attackerId, defenderId, sim, false);
|
|
||||||
}
|
|
||||||
sim.resume();
|
|
||||||
children.add(new MCTSNode(this, sim, sim.getCombat()));
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case SELECT_BLOCKERS:
|
|
||||||
// logger.info("Select blockers:" + player.getName());
|
|
||||||
List<List<List<UUID>>> blocks;
|
|
||||||
if (!USE_ACTION_CACHE)
|
|
||||||
blocks = player.getBlocks(game);
|
|
||||||
else
|
|
||||||
blocks = getBlocks(player, fullStateValue, game);
|
|
||||||
for (List<List<UUID>> block: blocks) {
|
|
||||||
Game sim = game.copy();
|
|
||||||
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
|
|
||||||
List<CombatGroup> groups = sim.getCombat().getGroups();
|
|
||||||
for (int i = 0; i < groups.size(); i++) {
|
|
||||||
if (i < block.size()) {
|
|
||||||
for (UUID blockerId: block.get(i)) {
|
|
||||||
simPlayer.declareBlocker(simPlayer.getId(), blockerId, groups.get(i).getAttackers().get(0), sim);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sim.resume();
|
|
||||||
children.add(new MCTSNode(this, sim, sim.getCombat()));
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
game = null;
|
game = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -462,8 +405,8 @@ public class MCTSNode {
|
||||||
private static long attacksMiss = 0;
|
private static long attacksMiss = 0;
|
||||||
private static long blocksHit = 0;
|
private static long blocksHit = 0;
|
||||||
private static long blocksMiss = 0;
|
private static long blocksMiss = 0;
|
||||||
|
|
||||||
private static List<Ability> getPlayables(MCTSPlayer player, String state, Game game) {
|
protected static List<Ability> getPlayables(MCTSPlayer player, String state, Game game) {
|
||||||
if (playablesCache.containsKey(state)) {
|
if (playablesCache.containsKey(state)) {
|
||||||
playablesHit++;
|
playablesHit++;
|
||||||
return playablesCache.get(state);
|
return playablesCache.get(state);
|
||||||
|
|
@ -475,8 +418,8 @@ public class MCTSNode {
|
||||||
return abilities;
|
return abilities;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<List<UUID>> getAttacks(MCTSPlayer player, String state, Game game) {
|
protected static List<List<UUID>> getAttacks(MCTSPlayer player, String state, Game game) {
|
||||||
if (attacksCache.containsKey(state)) {
|
if (attacksCache.containsKey(state)) {
|
||||||
attacksHit++;
|
attacksHit++;
|
||||||
return attacksCache.get(state);
|
return attacksCache.get(state);
|
||||||
|
|
@ -489,7 +432,7 @@ public class MCTSNode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<List<List<UUID>>> getBlocks(MCTSPlayer player, String state, Game game) {
|
protected static List<List<List<UUID>>> getBlocks(MCTSPlayer player, String state, Game game) {
|
||||||
if (blocksCache.containsKey(state)) {
|
if (blocksCache.containsKey(state)) {
|
||||||
blocksHit++;
|
blocksHit++;
|
||||||
return blocksCache.get(state);
|
return blocksCache.get(state);
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
package mage.player.ai;
|
||||||
|
|
||||||
|
import mage.game.Game;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface MCTSNodeNextAction {
|
||||||
|
List<MCTSNode> performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue);
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
package mage.player.ai;
|
||||||
|
|
||||||
|
import mage.abilities.Ability;
|
||||||
|
import mage.abilities.ActivatedAbility;
|
||||||
|
import mage.game.Game;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class PriorityNextAction implements MCTSNodeNextAction{
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<MCTSNode> performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue) {
|
||||||
|
List<MCTSNode> children = new ArrayList<>();
|
||||||
|
List<Ability> abilities;
|
||||||
|
if (!MCTSNode.USE_ACTION_CACHE)
|
||||||
|
abilities = player.getPlayableOptions(game);
|
||||||
|
else
|
||||||
|
abilities = MCTSNode.getPlayables(player, fullStateValue, game);
|
||||||
|
for (Ability ability: abilities) {
|
||||||
|
Game sim = game.copy();
|
||||||
|
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
|
||||||
|
simPlayer.activateAbility((ActivatedAbility)ability, sim);
|
||||||
|
sim.resume();
|
||||||
|
children.add(new MCTSNode(node, sim, ability));
|
||||||
|
}
|
||||||
|
|
||||||
|
return children;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,33 @@
|
||||||
|
package mage.player.ai;
|
||||||
|
|
||||||
|
import mage.game.Game;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
import static mage.player.ai.MCTSNode.getAttacks;
|
||||||
|
|
||||||
|
public class SelectAttackersNextAction implements MCTSNodeNextAction{
|
||||||
|
@Override
|
||||||
|
public List<MCTSNode> performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue) {
|
||||||
|
List<MCTSNode> children = new ArrayList<>();
|
||||||
|
List<List<UUID>> attacks;
|
||||||
|
if (!MCTSNode.USE_ACTION_CACHE)
|
||||||
|
attacks = player.getAttacks(game);
|
||||||
|
else
|
||||||
|
attacks = getAttacks(player, fullStateValue, game);
|
||||||
|
UUID defenderId = game.getOpponents(player.getId()).iterator().next();
|
||||||
|
for (List<UUID> attack: attacks) {
|
||||||
|
Game sim = game.copy();
|
||||||
|
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
|
||||||
|
for (UUID attackerId: attack) {
|
||||||
|
simPlayer.declareAttacker(attackerId, defenderId, sim, false);
|
||||||
|
}
|
||||||
|
sim.resume();
|
||||||
|
children.add(new MCTSNode(node, sim, sim.getCombat()));
|
||||||
|
}
|
||||||
|
|
||||||
|
return children;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,38 @@
|
||||||
|
package mage.player.ai;
|
||||||
|
|
||||||
|
import mage.game.Game;
|
||||||
|
import mage.game.combat.CombatGroup;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
import static mage.player.ai.MCTSNode.getBlocks;
|
||||||
|
|
||||||
|
public class SelectBlockersNextAction implements MCTSNodeNextAction{
|
||||||
|
@Override
|
||||||
|
public List<MCTSNode> performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue) {
|
||||||
|
List<MCTSNode> children = new ArrayList<>();
|
||||||
|
List<List<List<UUID>>> blocks;
|
||||||
|
if (!MCTSNode.USE_ACTION_CACHE)
|
||||||
|
blocks = player.getBlocks(game);
|
||||||
|
else
|
||||||
|
blocks = getBlocks(player, fullStateValue, game);
|
||||||
|
for (List<List<UUID>> block : blocks) {
|
||||||
|
Game sim = game.copy();
|
||||||
|
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
|
||||||
|
List<CombatGroup> groups = sim.getCombat().getGroups();
|
||||||
|
for (int i = 0; i < groups.size(); i++) {
|
||||||
|
if (i < block.size()) {
|
||||||
|
for (UUID blockerId : block.get(i)) {
|
||||||
|
simPlayer.declareBlocker(simPlayer.getId(), blockerId, groups.get(i).getAttackers().get(0), sim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sim.resume();
|
||||||
|
children.add(new MCTSNode(node, sim, sim.getCombat()));
|
||||||
|
}
|
||||||
|
|
||||||
|
return children;
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue