AI: fixed any MCTS (Monte Carlo) simulations fail, added catch error on MCTS fails (#10154);

This commit is contained in:
Oleg Agafonov 2024-06-10 20:37:32 +04:00
parent fe4814680e
commit fe55e67ca3

View file

@ -12,15 +12,13 @@ import mage.game.combat.Combat;
import mage.game.combat.CombatGroup;
import mage.player.ai.MCTSPlayer.NextAction;
import mage.players.Player;
import mage.util.ThreadUtils;
import org.apache.log4j.Logger;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.*;
/**
* @author BetaSteward_at_googlemail.com
@ -37,6 +35,8 @@ public class ComputerPlayerMCTS extends ComputerPlayer {
private static final Logger logger = Logger.getLogger(ComputerPlayerMCTS.class);
private int poolSize;
private ExecutorService threadPoolSimulations = null;
public ComputerPlayerMCTS(String name, RangeOfInfluence range, int skill) {
super(name, range);
human = false;
@ -161,7 +161,22 @@ public class ComputerPlayerMCTS extends ComputerPlayer {
if (thinkTime > 0) {
if (USE_MULTIPLE_THREADS) {
ExecutorService pool = Executors.newFixedThreadPool(poolSize);
if (this.threadPoolSimulations == null) {
// same params as Executors.newFixedThreadPool
// no needs errors check in afterExecute here cause that pool used for FutureTask with result check already
this.threadPoolSimulations = new ThreadPoolExecutor(
poolSize,
poolSize,
0L,
TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<>(),
r -> {
Thread thread = new Thread(r);
thread.setName(ThreadUtils.THREAD_PREFIX_AI_SIMULATION + "-MCTS-" + thread.getId());
return thread;
});
}
List<MCTSExecutor> tasks = new ArrayList<>();
for (int i = 0; i < poolSize; i++) {
Game sim = createMCTSGame(game);
@ -172,11 +187,18 @@ public class ComputerPlayerMCTS extends ComputerPlayer {
}
try {
pool.invokeAll(tasks, thinkTime, TimeUnit.SECONDS);
pool.awaitTermination(1, TimeUnit.SECONDS);
pool.shutdownNow();
} catch (InterruptedException | RejectedExecutionException ex) {
logger.warn("applyMCTS interrupted");
List<Future<Boolean>> runningTasks = threadPoolSimulations.invokeAll(tasks, thinkTime, TimeUnit.SECONDS);
for (Future<Boolean> runningTask : runningTasks) {
runningTask.get();
}
} catch (InterruptedException | CancellationException e) {
logger.warn("applyMCTS timeout");
} catch (ExecutionException e) {
// real games: must catch and log
// unit tests: must raise again for fast fail
if (this.isTestsMode()) {
throw new IllegalStateException("One of the simulated games raise the error: " + e, e);
}
}
int simCount = 0;
@ -280,6 +302,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer {
Player origPlayer = game.getState().getPlayers().get(copyPlayer.getId());
MCTSPlayer newPlayer = new MCTSPlayer(copyPlayer.getId());
newPlayer.restore(origPlayer);
newPlayer.setMatchPlayer(origPlayer.getMatchPlayer());
if (!newPlayer.getId().equals(playerId)) {
int handSize = newPlayer.getHand().size();
newPlayer.getLibrary().addAll(newPlayer.getHand().getCards(mcts), mcts);