refactor: improved usage of NextAction in mcts AI code (#11480)

Replaced conditional of selecting next action with runtime polymorphism, to increase the readability and easier for future changes by following Open/Close principle.
This commit is contained in:
Tirth Bharatiya 2023-11-28 12:42:59 +05:30 committed by GitHub
parent 2ef9439773
commit 7913c01ec3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 137 additions and 63 deletions

View file

@ -0,0 +1,21 @@
package mage.player.ai;
import java.util.HashMap;
public class MCTSNextActionFactory {
private static final HashMap<MCTSPlayer.NextAction, MCTSNodeNextAction> strategyMap = new HashMap<>();
static {
strategyMap.put(MCTSPlayer.NextAction.PRIORITY, new PriorityNextAction());
strategyMap.put(MCTSPlayer.NextAction.SELECT_BLOCKERS, new SelectBlockersNextAction());
strategyMap.put(MCTSPlayer.NextAction.SELECT_ATTACKERS, new SelectAttackersNextAction());
}
public static MCTSNodeNextAction createNextAction(MCTSPlayer.NextAction nextAction) {
MCTSNodeNextAction strategy = strategyMap.get(nextAction);
if (strategy == null) {
throw new IllegalArgumentException("Unsupported action: " + nextAction);
}
return strategy;
}
}

View file

@ -127,64 +127,7 @@ public class MCTSNode {
if (player.getNextAction() == null) {
logger.fatal("next action is null");
}
switch (player.getNextAction()) {
case PRIORITY:
// logger.info("Priority for player:" + player.getName() + " turn: " + game.getTurnNum() + " phase: " + game.getTurnPhaseType() + " step: " + game.getTurnStepType());
List<Ability> abilities;
if (!USE_ACTION_CACHE)
abilities = player.getPlayableOptions(game);
else
abilities = getPlayables(player, fullStateValue, game);
for (Ability ability: abilities) {
Game sim = game.copy();
// logger.info("expand " + ability.toString());
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
simPlayer.activateAbility((ActivatedAbility)ability, sim);
sim.resume();
children.add(new MCTSNode(this, sim, ability));
}
break;
case SELECT_ATTACKERS:
// logger.info("Select attackers:" + player.getName());
List<List<UUID>> attacks;
if (!USE_ACTION_CACHE)
attacks = player.getAttacks(game);
else
attacks = getAttacks(player, fullStateValue, game);
UUID defenderId = game.getOpponents(player.getId()).iterator().next();
for (List<UUID> attack: attacks) {
Game sim = game.copy();
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
for (UUID attackerId: attack) {
simPlayer.declareAttacker(attackerId, defenderId, sim, false);
}
sim.resume();
children.add(new MCTSNode(this, sim, sim.getCombat()));
}
break;
case SELECT_BLOCKERS:
// logger.info("Select blockers:" + player.getName());
List<List<List<UUID>>> blocks;
if (!USE_ACTION_CACHE)
blocks = player.getBlocks(game);
else
blocks = getBlocks(player, fullStateValue, game);
for (List<List<UUID>> block: blocks) {
Game sim = game.copy();
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
List<CombatGroup> groups = sim.getCombat().getGroups();
for (int i = 0; i < groups.size(); i++) {
if (i < block.size()) {
for (UUID blockerId: block.get(i)) {
simPlayer.declareBlocker(simPlayer.getId(), blockerId, groups.get(i).getAttackers().get(0), sim);
}
}
}
sim.resume();
children.add(new MCTSNode(this, sim, sim.getCombat()));
}
break;
}
children.addAll(MCTSNextActionFactory.createNextAction(player.getNextAction()).performNextAction(this, player, game, fullStateValue));
game = null;
}
@ -462,8 +405,8 @@ public class MCTSNode {
private static long attacksMiss = 0;
private static long blocksHit = 0;
private static long blocksMiss = 0;
private static List<Ability> getPlayables(MCTSPlayer player, String state, Game game) {
protected static List<Ability> getPlayables(MCTSPlayer player, String state, Game game) {
if (playablesCache.containsKey(state)) {
playablesHit++;
return playablesCache.get(state);
@ -475,8 +418,8 @@ public class MCTSNode {
return abilities;
}
}
private static List<List<UUID>> getAttacks(MCTSPlayer player, String state, Game game) {
protected static List<List<UUID>> getAttacks(MCTSPlayer player, String state, Game game) {
if (attacksCache.containsKey(state)) {
attacksHit++;
return attacksCache.get(state);
@ -489,7 +432,7 @@ public class MCTSNode {
}
}
private static List<List<List<UUID>>> getBlocks(MCTSPlayer player, String state, Game game) {
protected static List<List<List<UUID>>> getBlocks(MCTSPlayer player, String state, Game game) {
if (blocksCache.containsKey(state)) {
blocksHit++;
return blocksCache.get(state);

View file

@ -0,0 +1,9 @@
package mage.player.ai;
import mage.game.Game;
import java.util.List;
public interface MCTSNodeNextAction {
List<MCTSNode> performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue);
}

View file

@ -0,0 +1,30 @@
package mage.player.ai;
import mage.abilities.Ability;
import mage.abilities.ActivatedAbility;
import mage.game.Game;
import java.util.ArrayList;
import java.util.List;
public class PriorityNextAction implements MCTSNodeNextAction{
@Override
public List<MCTSNode> performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue) {
List<MCTSNode> children = new ArrayList<>();
List<Ability> abilities;
if (!MCTSNode.USE_ACTION_CACHE)
abilities = player.getPlayableOptions(game);
else
abilities = MCTSNode.getPlayables(player, fullStateValue, game);
for (Ability ability: abilities) {
Game sim = game.copy();
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
simPlayer.activateAbility((ActivatedAbility)ability, sim);
sim.resume();
children.add(new MCTSNode(node, sim, ability));
}
return children;
}
}

View file

@ -0,0 +1,33 @@
package mage.player.ai;
import mage.game.Game;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import static mage.player.ai.MCTSNode.getAttacks;
public class SelectAttackersNextAction implements MCTSNodeNextAction{
@Override
public List<MCTSNode> performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue) {
List<MCTSNode> children = new ArrayList<>();
List<List<UUID>> attacks;
if (!MCTSNode.USE_ACTION_CACHE)
attacks = player.getAttacks(game);
else
attacks = getAttacks(player, fullStateValue, game);
UUID defenderId = game.getOpponents(player.getId()).iterator().next();
for (List<UUID> attack: attacks) {
Game sim = game.copy();
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
for (UUID attackerId: attack) {
simPlayer.declareAttacker(attackerId, defenderId, sim, false);
}
sim.resume();
children.add(new MCTSNode(node, sim, sim.getCombat()));
}
return children;
}
}

View file

@ -0,0 +1,38 @@
package mage.player.ai;
import mage.game.Game;
import mage.game.combat.CombatGroup;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import static mage.player.ai.MCTSNode.getBlocks;
public class SelectBlockersNextAction implements MCTSNodeNextAction{
@Override
public List<MCTSNode> performNextAction(MCTSNode node, MCTSPlayer player, Game game, String fullStateValue) {
List<MCTSNode> children = new ArrayList<>();
List<List<List<UUID>>> blocks;
if (!MCTSNode.USE_ACTION_CACHE)
blocks = player.getBlocks(game);
else
blocks = getBlocks(player, fullStateValue, game);
for (List<List<UUID>> block : blocks) {
Game sim = game.copy();
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
List<CombatGroup> groups = sim.getCombat().getGroups();
for (int i = 0; i < groups.size(); i++) {
if (i < block.size()) {
for (UUID blockerId : block.get(i)) {
simPlayer.declareBlocker(simPlayer.getId(), blockerId, groups.get(i).getAttackers().get(0), sim);
}
}
}
sim.resume();
children.add(new MCTSNode(node, sim, sim.getCombat()));
}
return children;
}
}