single-threaded monte carlo + some fixes

This commit is contained in:
BetaSteward 2012-01-16 18:55:35 -05:00
parent 2e21b7197b
commit 377dd54fca
4 changed files with 106 additions and 51 deletions

View file

@ -52,7 +52,8 @@ import org.apache.log4j.Logger;
*/
public class ComputerPlayerMCTS extends ComputerPlayer<ComputerPlayerMCTS> implements Player {
private static final int thinkTimeRatioThreshold = 20;
private static final int THINK_MIN_RATIO = 20;
private static final int THINK_MAX_RATIO = 100;
protected transient MCTSNode root;
protected int maxThinkTime;
@ -112,8 +113,10 @@ public class ComputerPlayerMCTS extends ComputerPlayer<ComputerPlayerMCTS> imple
if (root != null) {
MCTSNode newRoot = null;
newRoot = root.getMatchingState(game.getState().getValue(false, game));
if (newRoot != null)
if (newRoot != null) {
newRoot.emancipate();
logger.info("choose action:" + newRoot.getAction() + " success ratio: " + newRoot.getWinRatio());
}
else
logger.info("unable to find matching state");
root = newRoot;
@ -258,28 +261,58 @@ public class ComputerPlayerMCTS extends ComputerPlayer<ComputerPlayerMCTS> imple
logger.info("applyMCTS - Thinking for " + (endTime - startTime)/1000000000.0 + "s");
if (thinkTime > 0) {
List<MCTSExecutor> tasks = new ArrayList<MCTSExecutor>();
for (int i = 0; i < cores; i++) {
Game sim = createMCTSGame(game);
MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId);
player.setNextAction(action);
MCTSExecutor exec = new MCTSExecutor(sim, playerId, thinkTime);
tasks.add(exec);
// List<MCTSExecutor> tasks = new ArrayList<MCTSExecutor>();
// for (int i = 0; i < cores; i++) {
// Game sim = createMCTSGame(game);
// MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId);
// player.setNextAction(action);
// MCTSExecutor exec = new MCTSExecutor(sim, playerId, thinkTime);
// tasks.add(exec);
// }
//
// try {
// pool.invokeAll(tasks);
// } catch (InterruptedException ex) {
// logger.warn("applyMCTS interrupted");
// }
//
// for (MCTSExecutor task: tasks) {
// root.merge(task.getRoot());
// task.clear();
// }
// tasks.clear();
MCTSNode current;
int simCount = 0;
while (true) {
long currentTime = System.nanoTime();
if (currentTime > endTime)
break;
current = root;
// Selection
while (!current.isLeaf()) {
current = current.select(this.playerId);
}
int result;
if (!current.isTerminal()) {
// Expansion
current.expand();
// Simulation
current = current.select(this.playerId);
result = current.simulate(this.playerId);
simCount++;
}
else {
result = current.isWinner(this.playerId)?1:-1;
}
// Backpropagation
current.backpropagate(result);
}
try {
pool.invokeAll(tasks);
} catch (InterruptedException ex) {
logger.warn("applyMCTS interrupted");
}
for (MCTSExecutor task: tasks) {
root.merge(task.getRoot());
task.clear();
}
tasks.clear();
logger.info("Created " + root.getNodeCount() + " nodes - size: " + root.size());
logger.info("Simulated " + simCount + " games - nodes in tree: " + root.size());
displayMemory();
}
@ -287,32 +320,38 @@ public class ComputerPlayerMCTS extends ComputerPlayer<ComputerPlayerMCTS> imple
return;
}
//try to ensure that there are at least 20 simulations per node at all times
//try to ensure that there are at least THINK_MIN_RATIO simulations per node at all times
private int calculateThinkTime(Game game, NextAction action) {
int thinkTime = 0;
int nodeSizeRatio = 0;
if (root.getNumChildren() > 0)
nodeSizeRatio = root.size() / root.getNumChildren();
nodeSizeRatio = root.getVisits() / root.getNumChildren();
logger.info("Ratio: " + nodeSizeRatio);
PhaseStep curStep = game.getStep().getType();
if (action == NextAction.SELECT_ATTACKERS || action == NextAction.SELECT_BLOCKERS) {
if (nodeSizeRatio < thinkTimeRatioThreshold) {
if (nodeSizeRatio < THINK_MIN_RATIO) {
thinkTime = maxThinkTime;
}
else if (nodeSizeRatio >= THINK_MAX_RATIO) {
thinkTime = 0;
}
else {
thinkTime = maxThinkTime / 2;
}
}
else if (game.getActivePlayerId().equals(playerId) && (curStep == PhaseStep.PRECOMBAT_MAIN || curStep == PhaseStep.POSTCOMBAT_MAIN)) {
if (nodeSizeRatio < thinkTimeRatioThreshold) {
else if (game.getActivePlayerId().equals(playerId) && (curStep == PhaseStep.PRECOMBAT_MAIN || curStep == PhaseStep.POSTCOMBAT_MAIN) && game.getStack().isEmpty()) {
if (nodeSizeRatio < THINK_MIN_RATIO) {
thinkTime = maxThinkTime;
}
else if (nodeSizeRatio >= THINK_MAX_RATIO) {
thinkTime = 0;
}
else {
thinkTime = maxThinkTime / 2;
}
}
else {
if (nodeSizeRatio < thinkTimeRatioThreshold) {
if (nodeSizeRatio < THINK_MIN_RATIO) {
thinkTime = maxThinkTime / 2;
}
else {