@@ -40,6 +40,7 @@ def __init__(
4040 finetuning_model : str | None = None ,
4141 launch_kwargs : dict [str , Any ] | None = None ,
4242 train_kwargs : dict [str , Any ] | None = None ,
43+ use_developer_role : bool = False ,
4344 ** kwargs ,
4445 ):
4546 """
@@ -77,6 +78,7 @@ def __init__(
7778 self .finetuning_model = finetuning_model
7879 self .launch_kwargs = launch_kwargs or {}
7980 self .train_kwargs = train_kwargs or {}
81+ self .use_developer_role = use_developer_role
8082 self ._warned_zero_temp_rollout = False
8183
8284 # Handle model-specific configuration for different model families
@@ -131,6 +133,11 @@ def forward(self, prompt=None, messages=None, **kwargs):
131133 cache = kwargs .pop ("cache" , self .cache )
132134
133135 messages = messages or [{"role" : "user" , "content" : prompt }]
136+ if self .use_developer_role and self .model_type == "responses" :
137+ messages = [
138+ {** m , "role" : "developer" } if m .get ("role" ) == "system" else m
139+ for m in messages
140+ ]
134141 kwargs = {** self .kwargs , ** kwargs }
135142 self ._warn_zero_temp_rollout (kwargs .get ("temperature" ), kwargs .get ("rollout_id" ))
136143 if kwargs .get ("rollout_id" ) is None :
@@ -162,6 +169,11 @@ async def aforward(self, prompt=None, messages=None, **kwargs):
162169 cache = kwargs .pop ("cache" , self .cache )
163170
164171 messages = messages or [{"role" : "user" , "content" : prompt }]
172+ if self .use_developer_role and self .model_type == "responses" :
173+ messages = [
174+ {** m , "role" : "developer" } if m .get ("role" ) == "system" else m
175+ for m in messages
176+ ]
165177 kwargs = {** self .kwargs , ** kwargs }
166178 self ._warn_zero_temp_rollout (kwargs .get ("temperature" ), kwargs .get ("rollout_id" ))
167179 if kwargs .get ("rollout_id" ) is None :
0 commit comments