diff --git a/Game.py b/Game.py
index e69de29..967ea44 100644
--- a/Game.py
+++ b/Game.py
@@ -0,0 +1,73 @@
+import numpy as np
+from State import State
+
+class Game:
+    def __init__(self, distances: np.ndarray):
+        """
+        The game is a TSP. A TSP can be described as a complete graph where each node is numbered with
+        numbers 0, ..., n_nodes -1 and each edge (i, j) has as attribute the distance between nodes i and j.
+        That is the Game is fully described by providing the distance matrix between nodes.
+
+        :param distances: np.ndarray, distance matrix.
+        """
+        self.distances = distances
+        self.n_nodes = self.distances.shape[0]
+        self.all_actions = [x for x in range(self.n_nodes)]
+
+    def available_actions(self, state: State) -> list[int]:
+        """
+        Returns the list of available actions at a given state.
+
+        :param state: State, state of the game.
+        :return: list, all available actions.
+        """
+        return [a for a in range(self.n_nodes) if a not in state.visited_nodes]
+
+    def game_over(self, state: State) -> bool:
+        """
+        True if the state is a final state, i.e. the game is over, False otherwise.
+
+        :param state: State, state of the game.
+        :return: bool, True for game over, False otherwise.
+        """
+        return self.n_nodes == len(state.visited_nodes)
+
+    def step(self, state: State, action: int) -> State:
+        """
+        Gives the new state achieved by performing the passed action from state.
+
+        :param state: State, state of the game.
+        :param action: int, action to be performed. Should be a feasible action.
+        :return: State, the new state reached.
+        """
+        visited_nodes = state.visited_nodes + [action]
+        return State(action, self.distances[action, :], visited_nodes)
+
+    def get_objective(self, state: State) -> float:
+        """
+        Give the length of the tour if the state is a final state, else raise an error.
+        :param state: State, state of the game.
+        :return: float, total length of the tour.
+ """ + if not self.game_over(state): + raise Exception("The objective of a partial solution is trying to be computed.") + + obj = self.distances[state.current_node, 0] + for node_idx in range(self.n_nodes - 1): + obj += self.distances[state.visited_nodes[node_idx], state.visited_nodes[node_idx + 1]] + + return obj + + def score(self, state: State, opponent_objective: float) -> int: + """ + Return if the game is won or not by the player. 1 if the game is won, -1 otherwise. + + :param state: State, state of the game. + :param opponent_objective: float, lenght of the tour found by the opponent. + :return: int, 1 if the lenght of the tour of the player is less or equal to the opponent's, -1 otherwise. + """ + player_objective = self.get_objective(state) + if player_objective <= opponent_objective: + return 1 + return -1 + diff --git a/MCTS.py b/MCTS.py new file mode 100644 index 0000000..4b20601 --- /dev/null +++ b/MCTS.py @@ -0,0 +1,109 @@ +from collections import defaultdict + +import numpy as np + +from State import State +from Game import Game +from NeuralNetwork import NN + +CPUCT = 1 +NUM_SIMULATIONS = 100 +TEMPERATURE = 1 + + +class MCTS: + def __init__(self, game: Game, nn: NN, opponent_objective: float): + """ + MCTS class primarly has 2 methods. The search method that, starting from the input state, runs a simulation + of the game choosing the actions by their Upper Confidence Bound value and updates the Ps, Qsa and Nsa values. + + A second method, get policy, is used to find an approximation for the optimal policy. This approximated optimal + policy is used, during training time, as a ground truth to train the Neural Network. This approximated policy + is computed starting from the Nsa values collected through many iterations of the search method. + + Ps is the prior distriution, given by a trained Neural Network. + Qsa is the Q-value of the state-action pair, approximated using the value of the games played in many iterations + of the search method. 
+        Nsa is the number of times the state-action pair is played. A high value in Nsa means that state-action pair
+        was considered promising many times, a low value means it wasn't.
+
+        :param game: Game, game that needs to be played.
+        :param nn: NeuralNetwork, neural net that gives the prior distribution.
+        :param opponent_objective: float, value to beat to decide if a game is won or not.
+        """
+        self.game = game
+        self.nn = nn
+        self.opponent_objective = opponent_objective
+        self.Ps = {}
+        self.Qsa = defaultdict(float)
+        self.Nsa = defaultdict(int)
+
+    def is_visited(self, state: State) -> bool:
+        """Helper method to check if a state was already visited through search iterations."""
+        return tuple(state.visited_nodes) in self.Ps  # tuple: a list is unhashable and cannot be a dict key
+
+    def search(self, state: State) -> float:
+        """
+        The search method collects values for Ps, Nsa and Qsa by playing a game until termination starting from state.
+
+        This is a recursive method. Return the value of the game if the state reached is a terminal state for the
+        game or if the state is unvisited (leaf node).
+        Else, decide an action to perform, update the state by performing such action and then apply the search
+        method to the new state. Then update the Nsa and Qsa value for the child.
+
+        The action is chosen by computing the Upper Confidence Bound of each action and selecting the one with highest
+        value. For computing the UCB, a probability distribution of the actions is needed. This probability distribution
+        Ps is predicted by a Neural Network nn.
+
+        Since only for the final state-action an exact value of the game is available, the value of each state
+        is also approximated using the nn. This approximated value is used to compute an approximation of the Q-value
+        Qsa. This Qsa is also updated at each iteration when a new approximation from a child node is available.
+
+        :param state: State, state from which the tree is explored.
+        :return: float, value of the game.
+ """ + s = state.visited_nodes + if self.game.game_over(state): + return self.game.score(state, self.opponent_objective) + + if not self.is_visited(state): + self.Ps[s], v = self.nn.predict(state) + return v + + valid_moves = self.game.available_actions(state) + + best_u = -float("inf") + best_a = None + for a in valid_moves: + u = self.Qsa[s, a] + (CPUCT * self.Ps[s][a] * np.sqrt(sum([self.Nsa[s, b] for b in valid_moves])) / + (self.Nsa[s, a] + 1)) + if u >= best_u: + best_u = u + best_a = a + + new_state = self.game.step(state, best_a) + v = self.search(new_state) + self.Qsa[(s, best_a)] = ( + (self.Nsa[(s, best_a)] * self.Qsa[(s, best_a)] + v) / + (self.Nsa[(s, best_a)] + 1) + ) + self.Nsa[(s, best_a)] += 1 + + return v + + def get_policy(self, state: State) -> list[float]: + for _ in range(NUM_SIMULATIONS): + self.search(state) + + s = state.visited_nodes + + if TEMPERATURE == 0: + policy = [0] * self.game.n_nodes + argmax = np.argmax([self.Nsa[s, a] for a in self.game.all_actions]) + policy[argmax] = 1 + else: + policy = [self.Nsa[s, a] ** (1 / TEMPERATURE) for a in self.game.all_actions] + sum_policy = sum(policy) + policy /= sum_policy + + return policy diff --git a/NeuralNetwork.py b/NeuralNetwork.py new file mode 100644 index 0000000..3d99183 --- /dev/null +++ b/NeuralNetwork.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import numpy as np + +class NN(nn.Module): + def __init__(self, input_len): + super(NN, self).__init__() + self.fc1 = nn.Linear(input_len, 64) + self.fc2 = nn.Linear(64, input_len) + self.fc3 = nn.Linear(64, 1) + + + def forward(self, input): + x = F.relu(self.fc1(input)) + p = F.softmax(self.fc2(x), dim=0) + v = torch.tanh(self.fc3(x)) + return p, v + + def train(self, examples): + """ + examples: list of examples, each example is of form (board, pi, v) + """ + batch_size = 10 + optimizer = optim.Adam(self.parameters()) + + for epoch in range(100): + batch_count = 
int(len(examples) / batch_size)
+
+            for _ in range(batch_count):
+                sample_ids = np.random.randint(len(examples), size=batch_size)
+                states, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
+                states = torch.FloatTensor(np.array(states).astype(np.float64))
+                target_pis = torch.FloatTensor(np.array(pis))
+                target_vs = torch.FloatTensor(np.array(vs).astype(np.float64))
+
+                # compute output
+                out_pi, out_v = self(states)
+                l_pi = self.loss_pi(target_pis, out_pi)
+                l_v = self.loss_v(target_vs, out_v)
+                total_loss = l_pi + l_v
+
+
+                # compute gradient and do SGD step
+                optimizer.zero_grad()
+                total_loss.backward()
+                optimizer.step()
+
+    def loss_pi(self, targets, outputs):
+        return -torch.sum(targets * torch.log(outputs + 1e-8)) / targets.size()[0]  # cross-entropy needs log-probs
+
+    def loss_v(self, targets, outputs):
+        return torch.sum((targets - outputs.view(-1)) ** 2) / targets.size()[0]
diff --git a/State.py b/State.py
new file mode 100644
index 0000000..a99b4aa
--- /dev/null
+++ b/State.py
@@ -0,0 +1,18 @@
+import numpy as np
+import torch
+
+class State:
+    def __init__(self, current_node: int, node_distances: np.ndarray, visited_nodes: list[int]):
+        self.current_node = current_node
+        self.node_distances = node_distances
+        self.visited_nodes = visited_nodes
+
+    def __len__(self):
+        return self.node_distances.shape[0] + 1
+
+    def to_tensor(self):
+        as_list = [self.current_node] + self.node_distances.tolist()
+        return torch.tensor(as_list)
+
+    def get_identifier(self):
+        return tuple(self.visited_nodes)  # tuple so the identifier is hashable (usable as a dict key)
diff --git a/unittest/test_game.py b/unittest/test_game.py
new file mode 100644
index 0000000..610b238
--- /dev/null
+++ b/unittest/test_game.py
@@ -0,0 +1,110 @@
+import unittest
+
+import numpy as np
+
+from Game import Game
+from State import State
+
+np.random.seed(42)
+
+n_nodes = 4
+A = np.random.rand(n_nodes, n_nodes)
+
+class TestAvailableActions(unittest.TestCase):
+    def setUp(self) -> None:
+        self.distances = (A.T * A) / 2  # elementwise product with the transpose -> symmetric matrix
+        self.game = Game(self.distances)
+
+    def test_available_actions(self):
+        state =
State(1, self.distances[:, 1], [0, 1])
+        avail_expected = [2, 3]
+        avail_returned = self.game.available_actions(state)
+        self.assertEqual(avail_returned, avail_expected)
+
+    def test_no_actions(self):
+        state = State(1, self.distances[:, 1], [0, 3, 2, 1])
+        avail_expected = []
+        avail_returned = self.game.available_actions(state)
+        self.assertEqual(avail_returned, avail_expected)
+
+class TestGameOver(unittest.TestCase):
+    def setUp(self) -> None:
+        self.distances = (A.T * A) / 2
+        self.game = Game(self.distances)
+
+        self.fixtures = (
+            (State(1, self.distances[:, 1], [0, 1]), False),
+            (State(1, self.distances[:, 1], [0, 3, 2, 1]), True),
+        )
+
+    def test_fixtures(self):
+        for state, expected in self.fixtures:
+            if expected:
+                self.assertTrue(self.game.game_over(state))
+            else:
+                self.assertFalse(self.game.game_over(state))
+
+
+class TestStep(unittest.TestCase):
+    def setUp(self) -> None:
+        self.distances = (A.T * A) / 2
+        self.game = Game(self.distances)
+
+        self.fixtures = (
+            (State(1, self.distances[:, 1], [0, 1]), 2, State(2, self.distances[:, 2], [0, 1, 2])),
+            (State(1, self.distances[:, 1], [0, 1]), 3, State(3, self.distances[:, 3], [0, 1, 3])),
+        )
+
+    def test_fixtures(self):
+        for state, action, expected in self.fixtures:
+            new_state = self.game.step(state, action)
+            self.assertEqual(new_state.current_node, expected.current_node)
+            self.assertTrue(all(new_state.node_distances == expected.node_distances))
+            self.assertEqual(new_state.visited_nodes, expected.visited_nodes)
+
+
+class TestGetObjective(unittest.TestCase):
+    def setUp(self) -> None:
+        self.distances = np.array(
+            [
+                [1, 2, 3, 4],
+                [2, 3, 4, 5],
+                [3, 4, 5, 6],
+                [4, 5, 6, 7]
+            ]
+        )
+        self.game = Game(self.distances)
+        self.state = State(3, self.distances[:, 3], [0, 1, 2, 3])
+
+    def test_objective(self):
+        obj_returned = self.game.get_objective(self.state)
+        obj_expected = 16
+        self.assertEqual(obj_expected, obj_returned)
+
+
+class TestScore(unittest.TestCase):
+    def setUp(self) ->
None:
+        self.distances = np.array(
+            [
+                [1, 2, 3, 4],
+                [2, 3, 4, 5],
+                [3, 4, 5, 6],
+                [4, 5, 6, 7]
+            ]
+        )
+        self.game = Game(self.distances)
+        self.state = State(3, self.distances[:, 3], [0, 1, 2, 3])
+        self.fixtures = (
+            (15, -1),
+            (16, 1),
+            (17, 1),
+        )
+
+    def test_fixtures(self):
+        for opponent_obj, score_expected in self.fixtures:
+            score_returned = self.game.score(self.state, opponent_obj)
+            self.assertEqual(score_returned, score_expected)
+
+
+if __name__ == "__main__":
+    unittest.main()