Simulated AI: Fix multi-target spell simulation. (#2181)

* Simulated AI: Fix multi-target spell simulation.

Removes an incorrect check that was comparing two semantically different things (number of possible to choose from vs. number of targets chosen). Adds a test using Incremental Growth, where number of possible targets is 5, but the spell requires only 3 to be chosen.

* Fix the root issue and eliminate incorrect simulations.

* Fix infinite loop with invalid targets.

* Fix logic.
This commit is contained in:
asvitkine
2023-01-12 06:38:27 -05:00
committed by GitHub
parent 876668c370
commit c207a2def9
4 changed files with 150 additions and 70 deletions

View File

@@ -11,6 +11,10 @@ public class MultiTargetSelector {
public static class Targets { public static class Targets {
private ArrayList<PossibleTargetSelector.Targets> targets; private ArrayList<PossibleTargetSelector.Targets> targets;
public int size() {
return targets.size();
}
@Override @Override
public String toString() { public String toString() {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
@@ -24,8 +28,8 @@ public class MultiTargetSelector {
} }
} }
private List<PossibleTargetSelector> selectors; private final List<PossibleTargetSelector> selectors;
private List<SpellAbility> targetingSAs; private final List<SpellAbility> targetingSAs;
private int currentIndex; private int currentIndex;
public MultiTargetSelector(SpellAbility sa, List<AbilitySub> plannedSubs) { public MultiTargetSelector(SpellAbility sa, List<AbilitySub> plannedSubs) {
@@ -52,8 +56,8 @@ public class MultiTargetSelector {
public Targets getLastSelectedTargets() { public Targets getLastSelectedTargets() {
Targets targets = new Targets(); Targets targets = new Targets();
targets.targets = new ArrayList<>(selectors.size()); targets.targets = new ArrayList<>(selectors.size());
for (int i = 0; i < selectors.size(); i++) { for (PossibleTargetSelector selector : selectors) {
targets.targets.add(selectors.get(i).getLastSelectedTargets()); targets.targets.add(selector.getLastSelectedTargets());
} }
return targets; return targets;
} }
@@ -78,34 +82,62 @@ public class MultiTargetSelector {
currentIndex = -1; currentIndex = -1;
} }
public void selectTargetsByIndex(int i) { public boolean selectTargetsByIndex(int i) {
// The caller is telling us to select the i-th possible set of targets.
if (i < currentIndex) { if (i < currentIndex) {
reset(); reset();
} }
while (currentIndex < i) { while (currentIndex < i) {
selectNextTargets(); if (!selectNextTargets()) {
return false;
}
} }
return true;
} }
public boolean selectNextTargets() { private boolean selectTargetsStartingFrom(int selectorIndex) {
if (currentIndex == -1) { // Don't reset the current selector, as it still has the correct list of targets set and has
for (PossibleTargetSelector selector : selectors) { // to remember its current/next target index. Subsequent selectors need a reset since their
if (!selector.selectNextTargets()) { // possible targets may change based on what was chosen for earlier ones.
if (selectors.get(selectorIndex).selectNextTargets()) {
for (int i = selectorIndex + 1; i < selectors.size(); i++) {
selectors.get(i).reset();
if (!selectors.get(i).selectNextTargets()) {
return false; return false;
} }
} }
currentIndex = 0;
return true; return true;
} }
for (int i = selectors.size() - 1; i >= 0; i--) { return false;
if (selectors.get(i).selectNextTargets()) { }
currentIndex++;
public boolean selectNextTargets() {
if (selectors.size() == 0) {
return false;
}
if (currentIndex == -1) {
// Select the first set of targets (calls selectNextTargets() on each selector).
if (selectTargetsStartingFrom(0)) {
currentIndex = 0;
return true; return true;
} }
selectors.get(i).reset(); // No possible targets.
selectors.get(i).selectNextTargets(); return false;
} }
return false; // Subsequent call, first try selecting a new target for the last selector. If that doesn't
// work, backtrack (decrement selector index) and try selecting targets from there.
// This approach ensures that leaf selectors (end of list) are advanced first, before
// previous ones, so that we get an AA,AB,BA,BB ordering.
int selectorIndex = selectors.size() - 1;
while (!selectTargetsStartingFrom(selectorIndex)) {
if (selectorIndex == 0) {
// No more possible targets.
return false;
}
selectorIndex--;
}
currentIndex++;
return true;
} }
private static boolean conditionsAreMet(SpellAbility saOrSubSa) { private static boolean conditionsAreMet(SpellAbility saOrSubSa) {

View File

@@ -17,12 +17,11 @@ import forge.game.spellability.TargetRestrictions;
public class PossibleTargetSelector { public class PossibleTargetSelector {
private final SpellAbility sa; private final SpellAbility sa;
private SpellAbility targetingSa; private final SpellAbility targetingSa;
private int targetingSaIndex; private final int targetingSaIndex;
private int maxTargets; private int maxTargets;
private TargetRestrictions tgt; private int nextTargetIndex;
private int targetIndex; private final List<GameObject> validTargets = new ArrayList<>();
private List<GameObject> validTargets;
public static class Targets { public static class Targets {
final int targetingSaIndex; final int targetingSaIndex;
@@ -36,7 +35,7 @@ public class PossibleTargetSelector {
this.targetIndex = targetIndex; this.targetIndex = targetIndex;
this.description = description; this.description = description;
if (targetIndex < 0 || targetIndex >= originalTargetCount) { if (targetIndex != -1 && (targetIndex < 0 || targetIndex >= originalTargetCount)) {
throw new IllegalArgumentException("Invalid targetIndex=" + targetIndex); throw new IllegalArgumentException("Invalid targetIndex=" + targetIndex);
} }
} }
@@ -51,12 +50,11 @@ public class PossibleTargetSelector {
this.sa = sa; this.sa = sa;
this.targetingSa = targetingSa; this.targetingSa = targetingSa;
this.targetingSaIndex = targetingSaIndex; this.targetingSaIndex = targetingSaIndex;
this.validTargets = new ArrayList<>(); reset();
generateValidTargets(sa.getHostCard().getController());
} }
public void reset() { public void reset() {
targetIndex = 0; nextTargetIndex = 0;
validTargets.clear(); validTargets.clear();
generateValidTargets(sa.getHostCard().getController()); generateValidTargets(sa.getHostCard().getController());
} }
@@ -67,7 +65,7 @@ public class PossibleTargetSelector {
} }
sa.setActivatingPlayer(player, true); sa.setActivatingPlayer(player, true);
targetingSa.resetTargets(); targetingSa.resetTargets();
tgt = targetingSa.getTargetRestrictions(); TargetRestrictions tgt = targetingSa.getTargetRestrictions();
maxTargets = tgt.getMaxTargets(sa.getHostCard(), targetingSa); maxTargets = tgt.getMaxTargets(sa.getHostCard(), targetingSa);
SimilarTargetSkipper skipper = new SimilarTargetSkipper(); SimilarTargetSkipper skipper = new SimilarTargetSkipper();
@@ -80,8 +78,8 @@ public class PossibleTargetSelector {
} }
private static class SimilarTargetSkipper { private static class SimilarTargetSkipper {
private ArrayListMultimap<String, Card> validTargetsMap = ArrayListMultimap.create(); private final ArrayListMultimap<String, Card> validTargetsMap = ArrayListMultimap.create();
private HashMap<Card, String> cardTypeStrings = new HashMap<>(); private final HashMap<Card, String> cardTypeStrings = new HashMap<>();
private HashMap<Card, Integer> creatureScores; private HashMap<Card, Integer> creatureScores;
private int getCreatureScore(Card c) { private int getCreatureScore(Card c) {
@@ -190,16 +188,7 @@ public class PossibleTargetSelector {
} }
public Targets getLastSelectedTargets() { public Targets getLastSelectedTargets() {
return new Targets(targetingSaIndex, validTargets.size(), targetIndex - 1, targetingSa.getTargets().toString()); return new Targets(targetingSaIndex, validTargets.size(), nextTargetIndex - 1, targetingSa.getTargets().toString());
}
public boolean selectTargetsByIndex(int targetIndex) {
if (targetIndex >= validTargets.size()) {
return false;
}
selectTargetsByIndexImpl(targetIndex);
this.targetIndex = targetIndex + 1;
return true;
} }
public boolean selectTargets(Targets targets) { public boolean selectTargets(Targets targets) {
@@ -208,16 +197,16 @@ public class PossibleTargetSelector {
return false; return false;
} }
selectTargetsByIndexImpl(targets.targetIndex); selectTargetsByIndexImpl(targets.targetIndex);
this.targetIndex = targets.targetIndex + 1; this.nextTargetIndex = targets.targetIndex + 1;
return true; return true;
} }
public boolean selectNextTargets() { public boolean selectNextTargets() {
if (targetIndex >= validTargets.size()) { if (nextTargetIndex >= validTargets.size()) {
return false; return false;
} }
selectTargetsByIndexImpl(targetIndex); selectTargetsByIndexImpl(nextTargetIndex);
targetIndex++; nextTargetIndex++;
return true; return true;
} }
} }

View File

@@ -137,4 +137,10 @@ public class SimulationTest {
protected Card addCard(String name, Player p) { protected Card addCard(String name, Player p) {
return addCardToZone(name, p, ZoneType.Battlefield); return addCardToZone(name, p, ZoneType.Battlefield);
} }
protected void addCards(String name, int count, Player p) {
for (int i = 0; i < count; i++) {
addCardToZone(name, p, ZoneType.Battlefield);
}
}
} }

View File

@@ -95,11 +95,8 @@ public class SpellAbilityPickerSimulationTest extends SimulationTest {
Game game = initAndCreateGame(); Game game = initAndCreateGame();
Player p = game.getPlayers().get(1); Player p = game.getPlayers().get(1);
addCard("Island", p); addCards("Island", 2, p);
addCard("Island", p); addCards("Forest", 3, p);
addCard("Forest", p);
addCard("Forest", p);
addCard("Forest", p);
Card tatyova = addCardToZone("Tatyova, Benthic Druid", p, ZoneType.Hand); Card tatyova = addCardToZone("Tatyova, Benthic Druid", p, ZoneType.Hand);
addCardToZone("Forest", p, ZoneType.Hand); addCardToZone("Forest", p, ZoneType.Hand);
@@ -169,10 +166,7 @@ public class SpellAbilityPickerSimulationTest extends SimulationTest {
Game game = initAndCreateGame(); Game game = initAndCreateGame();
Player p = game.getPlayers().get(1); Player p = game.getPlayers().get(1);
addCard("Mountain", p); addCards("Mountain", 4, p);
addCard("Mountain", p);
addCard("Mountain", p);
addCard("Mountain", p);
Card spell = addCardToZone("Fiery Confluence", p, ZoneType.Hand); Card spell = addCardToZone("Fiery Confluence", p, ZoneType.Hand);
Player opponent = game.getPlayers().get(0); Player opponent = game.getPlayers().get(0);
@@ -198,10 +192,7 @@ public class SpellAbilityPickerSimulationTest extends SimulationTest {
Game game = initAndCreateGame(); Game game = initAndCreateGame();
Player p = game.getPlayers().get(1); Player p = game.getPlayers().get(1);
addCard("Mountain", p); addCards("Mountain", 4, p);
addCard("Mountain", p);
addCard("Mountain", p);
addCard("Mountain", p);
Card spell = addCardToZone("Fiery Confluence", p, ZoneType.Hand); Card spell = addCardToZone("Fiery Confluence", p, ZoneType.Hand);
Player opponent = game.getPlayers().get(0); Player opponent = game.getPlayers().get(0);
@@ -226,8 +217,7 @@ public class SpellAbilityPickerSimulationTest extends SimulationTest {
Game game = initAndCreateGame(); Game game = initAndCreateGame();
Player p = game.getPlayers().get(1); Player p = game.getPlayers().get(1);
addCard("Mountain", p); addCards("Mountain", 2, p);
addCard("Mountain", p);
Card spell = addCardToZone("Arc Trail", p, ZoneType.Hand); Card spell = addCardToZone("Arc Trail", p, ZoneType.Hand);
Player opponent = game.getPlayers().get(0); Player opponent = game.getPlayers().get(0);
@@ -289,8 +279,7 @@ public class SpellAbilityPickerSimulationTest extends SimulationTest {
Game game = initAndCreateGame(); Game game = initAndCreateGame();
Player p = game.getPlayers().get(1); Player p = game.getPlayers().get(1);
addCard("Mountain", p); addCards("Mountain", 2, p);
addCard("Mountain", p);
Card abbot = addCardToZone("Abbot of Keral Keep", p, ZoneType.Hand); Card abbot = addCardToZone("Abbot of Keral Keep", p, ZoneType.Hand);
addCardToZone("Lightning Bolt", p, ZoneType.Hand); addCardToZone("Lightning Bolt", p, ZoneType.Hand);
// Note: This assumes the top of library is revealed. If the AI is made // Note: This assumes the top of library is revealed. If the AI is made
@@ -321,9 +310,7 @@ public class SpellAbilityPickerSimulationTest extends SimulationTest {
Game game = initAndCreateGame(); Game game = initAndCreateGame();
Player p = game.getPlayers().get(1); Player p = game.getPlayers().get(1);
addCard("Mountain", p); addCards("Mountain", 3, p);
addCard("Mountain", p);
addCard("Mountain", p);
Card abbot = addCardToZone("Abbot of Keral Keep", p, ZoneType.Hand); Card abbot = addCardToZone("Abbot of Keral Keep", p, ZoneType.Hand);
// Note: This assumes the top of library is revealed. If the AI is made // Note: This assumes the top of library is revealed. If the AI is made
// smarter to not assume that, then this test can be updated to have // smarter to not assume that, then this test can be updated to have
@@ -426,8 +413,7 @@ public class SpellAbilityPickerSimulationTest extends SimulationTest {
Card blocker = addCard("Fugitive Wizard", opponent); Card blocker = addCard("Fugitive Wizard", opponent);
Card attacker1 = addCard("Dwarven Trader", p); Card attacker1 = addCard("Dwarven Trader", p);
attacker1.setSickness(false); attacker1.setSickness(false);
addCard("Swamp", p); addCards("Swamp", 2, p);
addCard("Swamp", p);
addCardToZone("Doom Blade", p, ZoneType.Hand); addCardToZone("Doom Blade", p, ZoneType.Hand);
game.getPhaseHandler().devModeSet(PhaseType.MAIN1, p); game.getPhaseHandler().devModeSet(PhaseType.MAIN1, p);
@@ -455,9 +441,7 @@ public class SpellAbilityPickerSimulationTest extends SimulationTest {
Player opponent = game.getPlayers().get(0); Player opponent = game.getPlayers().get(0);
addCardToZone("Chaos Warp", p, ZoneType.Hand); addCardToZone("Chaos Warp", p, ZoneType.Hand);
addCard("Mountain", p); addCards("Mountain", 3, p);
addCard("Mountain", p);
addCard("Mountain", p);
addCard("Plains", opponent); addCard("Plains", opponent);
addCard("Mountain", opponent); addCard("Mountain", opponent);
@@ -489,8 +473,7 @@ public class SpellAbilityPickerSimulationTest extends SimulationTest {
Game game = initAndCreateGame(); Game game = initAndCreateGame();
Player p = game.getPlayers().get(1); Player p = game.getPlayers().get(1);
addCard("Island", p); addCards("Forest", 2, p);
addCard("Island", p);
addCardToZone("Counterspell", p, ZoneType.Hand); addCardToZone("Counterspell", p, ZoneType.Hand);
addCardToZone("Unsummon", p, ZoneType.Hand); addCardToZone("Unsummon", p, ZoneType.Hand);
@@ -605,4 +588,74 @@ public class SpellAbilityPickerSimulationTest extends SimulationTest {
// Still, this test case exercises the code path and ensures we don't crash in this case. // Still, this test case exercises the code path and ensures we don't crash in this case.
AssertJUnit.assertEquals(1, picker.getNumSimulations()); AssertJUnit.assertEquals(1, picker.getNumSimulations());
} }
@Test
public void threeDistinctTargetSpell() {
Game game = initAndCreateGame();
Player p = game.getPlayers().get(1);
Player opponent = game.getPlayers().get(0);
addCardToZone("Incremental Growth", p, ZoneType.Hand);
addCards("Forest", 5, p);
addCard("Forest Bear", p);
addCard("Flying Men", opponent);
addCard("Runeclaw Bear", p);
addCard("Water Elemental", opponent);
addCard("Grizzly Bears", p);
game.getPhaseHandler().devModeSet(PhaseType.MAIN2, p);
game.getAction().checkStateEffects(true);
SpellAbilityPicker picker = new SpellAbilityPicker(game, p);
SpellAbility sa = picker.chooseSpellAbilityToPlay(null);
AssertJUnit.assertNotNull(sa);
MultiTargetSelector.Targets targets = picker.getPlan().getSelectedDecision().targets;
AssertJUnit.assertEquals(3, targets.size());
AssertJUnit.assertTrue(targets.toString().contains("Forest Bear"));
AssertJUnit.assertTrue(targets.toString().contains("Runeclaw Bear"));
AssertJUnit.assertTrue(targets.toString().contains("Grizzly Bear"));
// Expected 5*4*3=60 iterations (5 choices for first target, 4 for next, 3 for last.)
AssertJUnit.assertEquals(60, picker.getNumSimulations());
}
@Test
public void threeDistinctTargetSpellCantBeCast() {
Game game = initAndCreateGame();
Player p = game.getPlayers().get(1);
Player opponent = game.getPlayers().get(0);
addCardToZone("Incremental Growth", p, ZoneType.Hand);
addCards("Forest", 5, p);
addCard("Forest Bear", p);
addCard("Flying Men", opponent);
game.getPhaseHandler().devModeSet(PhaseType.MAIN2, p);
game.getAction().checkStateEffects(true);
SpellAbilityPicker picker = new SpellAbilityPicker(game, p);
SpellAbility sa = picker.chooseSpellAbilityToPlay(null);
AssertJUnit.assertNull(sa);
}
@Test
public void correctTargetChoicesWithTwoTargetSpell() {
Game game = initAndCreateGame();
Player p = game.getPlayers().get(1);
Player opponent = game.getPlayers().get(0);
addCardToZone("Rites of Reaping", p, ZoneType.Hand);
addCard("Swamp", p);
addCards("Forest", 5, p);
addCard("Flying Men", opponent);
addCard("Forest Bear", p);
addCard("Water Elemental", opponent);
game.getPhaseHandler().devModeSet(PhaseType.MAIN2, p);
game.getAction().checkStateEffects(true);
SpellAbilityPicker picker = new SpellAbilityPicker(game, p);
SpellAbility sa = picker.chooseSpellAbilityToPlay(null);
AssertJUnit.assertNotNull(sa);
MultiTargetSelector.Targets targets = picker.getPlan().getSelectedDecision().targets;
AssertJUnit.assertEquals(2, targets.size());
AssertJUnit.assertTrue(targets.toString().contains("Forest Bear"));
AssertJUnit.assertTrue(targets.toString().contains("Flying Men"));
}
} }