Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions pearl/policy_learners/contextual_bandits/neural_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,20 +132,27 @@ def act(
"""
Args:
subjective_state: state will be applied to different action vectors in action_space
action_space: contains a list of action vector, currenly only support static space
action_space: contains a list of action vector, currently only support static space
Return:
action index chosen given state and action vectors
"""
assert isinstance(available_action_space, DiscreteActionSpace)
# It doesnt make sense to call act if we are not working with action vector
# It doesn't make sense to call act if we are not working with action vector
action_count = available_action_space.n
new_feature = concatenate_actions_to_state(
subjective_state=subjective_state,
action_space=available_action_space,
state_features_only=self._state_features_only,
action_representation_module=self.action_representation_module,
)
values = self.model(new_feature).squeeze(-1)
batch_size = new_feature.shape[0]
feature_dim = new_feature.shape[-1]
# Flatten the action dimension before model evaluation
values = (
self.model(new_feature.reshape(-1, feature_dim))
.reshape(batch_size, -1)
.squeeze(-1)
)
# batch_size * action_count
assert values.numel() == new_feature.shape[0] * action_count
return self.exploration_module.act(
Expand Down
7 changes: 5 additions & 2 deletions pearl/replay_buffers/tensor_based_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,12 @@ def _process_non_optional_single_state(

def _process_single_action(self, action: Action) -> torch.Tensor:
    """Convert a single action into a tensor with a leading batch dimension.

    Args:
        action: either a ``torch.Tensor`` (moved to the default device and
            detached from any autograd graph) or any value accepted by
            ``torch.tensor`` (e.g. an int or a list of floats).

    Returns:
        A tensor whose first dimension is the batch axis. Scalar (0-D) and
        1-D actions are unsqueezed to shape ``(1, ...)``; actions that are
        already 2-D or higher are assumed to carry a batch dimension and
        are returned unchanged.
    """
    if isinstance(action, torch.Tensor):
        # clone().detach() so the stored copy shares neither storage nor
        # autograd history with the caller's tensor.
        tensor = action.to(get_default_device()).clone().detach()
    else:
        tensor = torch.tensor(action)
    # Only add a batch dimension when one is missing; unconditionally
    # unsqueezing (the old behavior) double-batched >= 2-D inputs.
    if tensor.ndim <= 1:
        tensor = tensor.unsqueeze(0)
    return tensor

def _process_single_reward(self, reward: Reward) -> torch.Tensor:
    """Wrap a scalar reward into a 1-D tensor of shape ``(1,)``."""
    reward_tensor = torch.tensor([reward])
    return reward_tensor
Expand Down
2 changes: 1 addition & 1 deletion test/unit/with_pytorch/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from unittest import TestCase

import torch
from pearl.test.unit.with_pytorch.test_agent import TestAgentWithPyTorch
from test.unit.with_pytorch.test_agent import TestAgentWithPyTorch
from torch import nn


Expand Down
Loading