fixes + optimizations + updates to monte carlo ai

This commit is contained in:
BetaSteward 2012-01-24 22:51:31 -05:00
parent 23616432e4
commit 7fce6c552d
16 changed files with 312 additions and 228 deletions

View file

@ -38,6 +38,10 @@ import mage.Constants.Zone;
import mage.abilities.Ability;
import mage.abilities.ActivatedAbility;
import mage.abilities.common.PassAbility;
import mage.abilities.costs.mana.GenericManaCost;
import mage.abilities.costs.mana.ManaCost;
import mage.abilities.costs.mana.ManaCosts;
import mage.abilities.costs.mana.VariableManaCost;
import mage.cards.Card;
import mage.game.Game;
import mage.game.combat.Combat;
@ -52,8 +56,9 @@ import org.apache.log4j.Logger;
*/
public class ComputerPlayerMCTS extends ComputerPlayer<ComputerPlayerMCTS> implements Player {
private static final int THINK_MIN_RATIO = 20;
private static final int THINK_MIN_RATIO = 40;
private static final int THINK_MAX_RATIO = 100;
private static final boolean USE_MULTIPLE_THREADS = false;
protected transient MCTSNode root;
protected int maxThinkTime;
@ -84,7 +89,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer<ComputerPlayerMCTS> imple
@Override
public boolean priority(Game game) {
if (game.getStep().getType() == PhaseStep.DRAW)
if (game.getStep().getType() == PhaseStep.PRECOMBAT_MAIN)
logList("computer player " + name + " hand: ", new ArrayList(hand.getCards(game)));
game.firePriorityEvent(playerId);
getNextAction(game, NextAction.PRIORITY);
@ -111,7 +116,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer<ComputerPlayerMCTS> imple
protected void getNextAction(Game game, NextAction nextAction) {
if (root != null) {
MCTSNode newRoot = null;
MCTSNode newRoot;
newRoot = root.getMatchingState(game.getState().getValue(false, game));
if (newRoot != null) {
newRoot.emancipate();
@ -161,7 +166,8 @@ public class ComputerPlayerMCTS extends ComputerPlayer<ComputerPlayerMCTS> imple
//
// @Override
// public boolean chooseUse(Outcome outcome, String message, Game game) {
// throw new UnsupportedOperationException("Not supported yet.");
// getNextAction(game, NextAction.CHOOSE_USE);
// return root.get
// }
//
// @Override
@ -174,11 +180,20 @@ public class ComputerPlayerMCTS extends ComputerPlayer<ComputerPlayerMCTS> imple
// throw new UnsupportedOperationException("Not supported yet.");
// }
// @Override
// public boolean playXMana(VariableManaCost cost, ManaCosts<ManaCost> costs, Game game) {
// throw new UnsupportedOperationException("Not supported yet.");
// }
//
@Override
public boolean playXMana(VariableManaCost cost, ManaCosts<ManaCost> costs, Game game) {
//MCTSPlayer.simulateVariableCosts method adds a generic mana cost for each option
for (ManaCost manaCost: costs) {
if (manaCost instanceof GenericManaCost) {
cost.setPayment(manaCost.getPayment());
logger.debug("using X = " + cost.getPayment().count());
break;
}
}
cost.setPaid();
return true;
}
// @Override
// public int chooseEffect(List<ReplacementEffect> rEffects, Game game) {
// throw new UnsupportedOperationException("Not supported yet.");
@ -261,68 +276,69 @@ public class ComputerPlayerMCTS extends ComputerPlayer<ComputerPlayerMCTS> imple
logger.info("applyMCTS - Thinking for " + (endTime - startTime)/1000000000.0 + "s");
if (thinkTime > 0) {
// List<MCTSExecutor> tasks = new ArrayList<MCTSExecutor>();
// for (int i = 0; i < cores; i++) {
// Game sim = createMCTSGame(game);
// MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId);
// player.setNextAction(action);
// MCTSExecutor exec = new MCTSExecutor(sim, playerId, thinkTime);
// tasks.add(exec);
// }
//
// try {
// pool.invokeAll(tasks);
// } catch (InterruptedException ex) {
// logger.warn("applyMCTS interrupted");
// }
//
// for (MCTSExecutor task: tasks) {
// root.merge(task.getRoot());
// task.clear();
// }
// tasks.clear();
MCTSNode current;
int simCount = 0;
while (true) {
long currentTime = System.nanoTime();
if (currentTime > endTime)
break;
current = root;
// Selection
while (!current.isLeaf()) {
current = current.select(this.playerId);
if (USE_MULTIPLE_THREADS) {
List<MCTSExecutor> tasks = new ArrayList<MCTSExecutor>();
for (int i = 0; i < cores; i++) {
Game sim = createMCTSGame(game);
MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId);
player.setNextAction(action);
MCTSExecutor exec = new MCTSExecutor(sim, playerId, thinkTime);
tasks.add(exec);
}
int result;
if (!current.isTerminal()) {
// Expansion
current.expand();
try {
pool.invokeAll(tasks);
} catch (InterruptedException ex) {
logger.warn("applyMCTS interrupted");
}
// Simulation
current = current.select(this.playerId);
result = current.simulate(this.playerId);
simCount++;
for (MCTSExecutor task: tasks) {
root.merge(task.getRoot());
task.clear();
}
else {
result = current.isWinner(this.playerId)?1:-1;
}
// Backpropagation
current.backpropagate(result);
tasks.clear();
}
else {
MCTSNode current;
int simCount = 0;
while (true) {
long currentTime = System.nanoTime();
if (currentTime > endTime)
break;
current = root;
logger.info("Simulated " + simCount + " games - nodes in tree: " + root.size());
// Selection
while (!current.isLeaf()) {
current = current.select(this.playerId);
}
int result;
if (!current.isTerminal()) {
// Expansion
current.expand();
// Simulation
current = current.select(this.playerId);
result = current.simulate(this.playerId);
simCount++;
}
else {
result = current.isWinner(this.playerId)?1:-1;
}
// Backpropagation
current.backpropagate(result);
}
logger.info("Simulated " + simCount + " games - nodes in tree: " + root.size());
}
displayMemory();
}
// root.print(1);
return;
}
//try to ensure that there are at least THINK_MIN_RATIO simulations per node at all times
private int calculateThinkTime(Game game, NextAction action) {
int thinkTime = 0;
int thinkTime;
int nodeSizeRatio = 0;
if (root.getNumChildren() > 0)
nodeSizeRatio = root.getVisits() / root.getNumChildren();