Skip to content

Commit 6c00645

Browse files
[CI][Pooling] Stabilize ModernBERT test (vllm-project#32909)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
1 parent b781eea commit 6c00645

1 file changed

Lines changed: 27 additions & 0 deletions

File tree

tests/models/language/pooling/test_token_classification.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import random
4+
5+
import numpy as np
36
import pytest
47
import torch
58
from transformers import AutoModelForTokenClassification
@@ -8,6 +11,20 @@
811
from vllm.platforms import current_platform
912

1013

14+
@pytest.fixture(autouse=True)
15+
def seed_everything():
16+
"""Seed all random number generators for reproducibility."""
17+
seed = 0
18+
random.seed(seed)
19+
np.random.seed(seed)
20+
torch.manual_seed(seed)
21+
if torch.cuda.is_available():
22+
torch.cuda.manual_seed_all(seed)
23+
torch.backends.cudnn.deterministic = True
24+
torch.backends.cudnn.benchmark = False
25+
yield
26+
27+
1128
@pytest.mark.parametrize("model", ["boltuix/NeuroBERT-NER"])
1229
# The float32 is required for this tiny model to pass the test.
1330
@pytest.mark.parametrize("dtype", ["float"])
@@ -51,6 +68,7 @@ def test_bert_models(
5168

5269
@pytest.mark.parametrize("model", ["disham993/electrical-ner-ModernBERT-base"])
5370
@pytest.mark.parametrize("dtype", ["float"])
71+
@pytest.mark.flaky(reruns=3)
5472
@torch.inference_mode
5573
def test_modernbert_models(
5674
hf_runner,
@@ -59,6 +77,15 @@ def test_modernbert_models(
5977
model: str,
6078
dtype: str,
6179
) -> None:
80+
# NOTE: https://github.com/vllm-project/vllm/pull/32403
81+
# `disham993/electrical-ner-ModernBERT-base` is a randomly initialized
82+
# model, which can cause numerical precision variance and edge cases.
83+
# We use @flaky(reruns=3) to mitigate intermittent failures.
84+
print(
85+
f"\n[NOTE] Testing {model} (randomly initialized weights) - "
86+
"flaky tolerance enabled due to numerical precision variance."
87+
)
88+
6289
with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
6390
vllm_outputs = vllm_model.token_classify(example_prompts)
6491

0 commit comments

Comments
 (0)