Hello, I want to use the model to compute a score for individual files. Based on the code in this repo, I wrote the following helper functions to do that.
I would like to verify that this is the correct approach.
import os
import numpy as np
import torch
import torch.nn.functional as F
from MalConvGCT_nocat import MalConvGCT  # adjust to the module where MalConvGCT is defined in this repo
import settings  # my own local configuration module

def load_model(checkpoint_path, device="cuda:0"):
    """Loads a pre-trained MalConvGCT model from a checkpoint."""
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    mlgct = MalConvGCT(channels=256, window_size=256, stride=64)
    # map_location lets the checkpoint load even if it was saved on a different device
    checkpoint = torch.load(checkpoint_path, map_location=device)
    mlgct.load_state_dict(checkpoint['model_state_dict'], strict=False)
    mlgct.to(device)  # move the model to the requested device
    mlgct.eval()      # set model to evaluation mode
    return mlgct
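As a quick sanity check that the checkpoint really stores its weights under 'model_state_dict', I inspect its keys first (a minimal sketch; the path comes from my own settings module):

ckpt = torch.load(settings.MALCONV2_CHECKPOINT_FILE, map_location="cpu")
print(ckpt.keys())  # expect 'model_state_dict' among the keys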
def get_score(model, file_path, max_len=settings.MAX_FILE_LEN_MALCONV2):
    """Takes a binary file and returns the predicted probability that it is malicious."""
    # Read up to max_len bytes of the binary file
    with open(file_path, 'rb') as f:
        data = f.read(max_len)
    # Shift byte values by +1 so that 0 is free to serve as the padding index
    x = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + 1
    # Use a long tensor: after the +1 shift the value 256 can occur, which would
    # overflow uint8, and the embedding layer expects integer indices anyway
    input_tensor = torch.tensor(x, dtype=torch.long).unsqueeze(0).to(settings.DEVICE)  # add batch dimension
    # Run the model (passed in as an argument, so it is loaded only once)
    with torch.no_grad():
        outputs, _, _ = model(input_tensor)
        # Softmax over the two classes; column 1 is the malicious-class probability
        score = F.softmax(outputs, dim=-1)[:, 1].cpu().numpy().ravel()
    return score[0]
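For completeness, this is how I call these helpers (a minimal usage sketch; `settings.MALCONV2_CHECKPOINT_FILE` and `settings.DEVICE` come from my own config, and the sample paths are placeholders):

# Load the model once, then score any number of files with it
model = load_model(settings.MALCONV2_CHECKPOINT_FILE, device=settings.DEVICE)
for path in ["samples/putty.exe", "samples/unknown.bin"]:  # hypothetical paths
    print(path, get_score(model, path))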