Skip to content

Commit 0664b3b

Browse files
Fix coding style in dspy/adapters (stanfordnlp#8155)
* enable style check * Fix style for dspy/adapters * remove extra
1 parent 1790410 commit 0664b3b

7 files changed

Lines changed: 54 additions & 45 deletions

File tree

dspy/adapters/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from dspy.adapters.base import Adapter
22
from dspy.adapters.chat_adapter import ChatAdapter
33
from dspy.adapters.json_adapter import JSONAdapter
4-
from dspy.adapters.types import Image, History
54
from dspy.adapters.two_step_adapter import TwoStepAdapter
5+
from dspy.adapters.types import History, Image
66

77
__all__ = [
88
"Adapter",
99
"ChatAdapter",
10-
"JSONAdapter",
11-
"Image",
1210
"History",
11+
"Image",
12+
"JSONAdapter",
1313
"TwoStepAdapter",
1414
]

dspy/adapters/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def format_assistant_message_content(
212212
self,
213213
signature: Type[Signature],
214214
outputs: dict[str, Any],
215-
missing_field_message: str = None,
215+
missing_field_message: Optional[str] = None,
216216
) -> str:
217217
"""Format the assistant message content.
218218

dspy/adapters/chat_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,6 @@ def format_finetune_data(
216216
assistant_message_content = self.format_assistant_message_content( # returns a string, without the role
217217
signature=signature, outputs=outputs
218218
)
219-
assistant_message = dict(role="assistant", content=assistant_message_content)
219+
assistant_message = {"role": "assistant", "content": assistant_message_content}
220220
messages = system_user_messages + [assistant_message]
221-
return dict(messages=messages)
221+
return {"messages": messages}

dspy/adapters/json_adapter.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import json
2-
import regex
32
import logging
43
from typing import Any, Dict, Type, get_origin
54

65
import json_repair
76
import litellm
87
import pydantic
8+
import regex
99
from pydantic.fields import FieldInfo
1010

1111
from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName
@@ -29,7 +29,7 @@ def _has_open_ended_mapping(signature: SignatureMeta) -> bool:
2929
such as dict[str, Any]. Structured Outputs require explicit properties, so such fields
3030
are incompatible.
3131
"""
32-
for name, field in signature.output_fields.items():
32+
for field in signature.output_fields.values():
3333
annotation = field.annotation
3434
if get_origin(annotation) is dict:
3535
return True
@@ -121,9 +121,9 @@ def format_assistant_message_content(
121121
return self.format_field_with_value(fields_with_values, role="assistant")
122122

123123
def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
124-
pattern = r'\{(?:[^{}]|(?R))*\}'
125-
match = regex.search(pattern, completion, regex.DOTALL)
126-
if match:
124+
pattern = r"\{(?:[^{}]|(?R))*\}"
125+
match = regex.search(pattern, completion, regex.DOTALL)
126+
if match:
127127
completion = match.group(0)
128128
fields = json_repair.loads(completion)
129129

@@ -196,10 +196,14 @@ def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[py
196196
fields[name] = (annotation, default)
197197

198198
# Build the model with extra fields forbidden.
199-
Model = pydantic.create_model("DSPyProgramOutputs", **fields, __config__=type("Config", (), {"extra": "forbid"}))
199+
pydantic_model = pydantic.create_model(
200+
"DSPyProgramOutputs",
201+
**fields,
202+
__config__=type("Config", (), {"extra": "forbid"}),
203+
)
200204

201205
# Generate the initial schema.
202-
schema = Model.model_json_schema()
206+
schema = pydantic_model.model_json_schema()
203207

204208
# Remove any DSPy-specific metadata.
205209
for prop in schema.get("properties", {}).values():
@@ -208,9 +212,9 @@ def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[py
208212
def enforce_required(schema_part: dict):
209213
"""
210214
Recursively ensure that:
211-
- for any object schema, a "required" key is added with all property names (or [] if no properties)
212-
- additionalProperties is set to False regardless of the previous value.
213-
- the same enforcement is run for nested arrays and definitions.
215+
- for any object schema, a "required" key is added with all property names (or [] if no properties)
216+
- additionalProperties is set to False regardless of the previous value.
217+
- the same enforcement is run for nested arrays and definitions.
214218
"""
215219
if schema_part.get("type") == "object":
216220
props = schema_part.get("properties")
@@ -237,6 +241,6 @@ def enforce_required(schema_part: dict):
237241
enforce_required(schema)
238242

239243
# Override the model's JSON schema generation to return our precomputed schema.
240-
Model.model_json_schema = lambda *args, **kwargs: schema
244+
pydantic_model.model_json_schema = lambda *args, **kwargs: schema
241245

242-
return Model
246+
return pydantic_model

dspy/adapters/two_step_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Type
1+
from typing import Any, Optional, Type
22

33
from dspy.adapters.base import Adapter
44
from dspy.adapters.chat_adapter import ChatAdapter
@@ -175,7 +175,7 @@ def format_assistant_message_content(
175175
self,
176176
signature: Type[Signature],
177177
outputs: dict[str, Any],
178-
missing_field_message: str = None,
178+
missing_field_message: Optional[str] = None,
179179
) -> str:
180180
parts = []
181181

dspy/adapters/types/image.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import base64
22
import io
3+
import mimetypes
34
import os
5+
import re
46
from typing import Any, Dict, List, Union
57
from urllib.parse import urlparse
6-
import re
7-
import mimetypes
88

99
import pydantic
1010
import requests
@@ -19,14 +19,14 @@
1919

2020
class Image(pydantic.BaseModel):
2121
url: str
22-
22+
2323
model_config = {
24-
'frozen': True,
25-
'str_strip_whitespace': True,
26-
'validate_assignment': True,
27-
'extra': 'forbid',
24+
"frozen": True,
25+
"str_strip_whitespace": True,
26+
"validate_assignment": True,
27+
"extra": "forbid",
2828
}
29-
29+
3030
@pydantic.model_validator(mode="before")
3131
@classmethod
3232
def validate_input(cls, values):
@@ -52,7 +52,7 @@ def from_file(cls, file_path: str):
5252
return cls(url=encode_image(file_path))
5353

5454
@classmethod
55-
def from_PIL(cls, pil_image):
55+
def from_PIL(cls, pil_image): # noqa: N802
5656
return cls(url=encode_image(pil_image))
5757

5858
@pydantic.model_serializer()
@@ -66,9 +66,10 @@ def __repr__(self):
6666
if "base64" in self.url:
6767
len_base64 = len(self.url.split("base64,")[1])
6868
image_type = self.url.split(";")[0].split("/")[-1]
69-
return f"Image(url=data:image/{image_type};base64,<IMAGE_BASE_64_ENCODED({str(len_base64)})>)"
69+
return f"Image(url=data:image/{image_type};base64,<IMAGE_BASE_64_ENCODED({len_base64!s})>)"
7070
return f"Image(url='{self.url}')"
7171

72+
7273
def is_url(string: str) -> bool:
7374
"""Check if a string is a valid URL."""
7475
try:
@@ -162,7 +163,7 @@ def _encode_image_from_url(image_url: str) -> str:
162163
return f"data:{mime_type};base64,{encoded_data}"
163164

164165

165-
def _encode_pil_image(image: 'PILImage') -> str:
166+
def _encode_pil_image(image: "PILImage") -> str:
166167
"""Encode a PIL Image object to a base64 data URI."""
167168
buffered = io.BytesIO()
168169
file_format = image.format or "PNG"
@@ -197,6 +198,7 @@ def is_image(obj) -> bool:
197198
return True
198199
return False
199200

201+
200202
def try_expand_image_tags(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
201203
"""Try to expand image tags in the messages."""
202204
for message in messages:
@@ -205,43 +207,44 @@ def try_expand_image_tags(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]
205207
message["content"] = expand_image_tags(message["content"])
206208
return messages
207209

210+
208211
def expand_image_tags(text: str) -> Union[str, List[Dict[str, Any]]]:
209-
"""Expand image tags in the text. If there are any image tags,
212+
"""Expand image tags in the text. If there are any image tags,
210213
turn it from a content string into a content list of texts and image urls.
211-
214+
212215
Args:
213216
text: The text content that may contain image tags
214-
217+
215218
Returns:
216219
Either the original string if no image tags, or a list of content dicts
217220
with text and image_url entries
218221
"""
219222
image_tag_regex = r'"?<DSPY_IMAGE_START>(.*?)<DSPY_IMAGE_END>"?'
220-
223+
221224
# If no image tags, return original text
222225
if not re.search(image_tag_regex, text):
223226
return text
224-
227+
225228
final_list = []
226229
remaining_text = text
227-
230+
228231
while remaining_text:
229232
match = re.search(image_tag_regex, remaining_text)
230233
if not match:
231234
if remaining_text.strip():
232235
final_list.append({"type": "text", "text": remaining_text.strip()})
233236
break
234-
237+
235238
# Get text before the image tag
236-
prefix = remaining_text[:match.start()].strip()
239+
prefix = remaining_text[: match.start()].strip()
237240
if prefix:
238241
final_list.append({"type": "text", "text": prefix})
239-
242+
240243
# Add the image
241244
image_url = match.group(1)
242245
final_list.append({"type": "image_url", "image_url": {"url": image_url}})
243-
246+
244247
# Update remaining text
245-
remaining_text = remaining_text[match.end():].strip()
246-
248+
remaining_text = remaining_text[match.end() :].strip()
249+
247250
return final_list

dspy/adapters/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def parse_value(value, annotation):
154154

155155
if v in allowed:
156156
return v
157-
157+
158158
raise ValueError(f"{value!r} is not one of {allowed!r}")
159159

160160
if not isinstance(value, str):
@@ -174,6 +174,7 @@ def parse_value(value, annotation):
174174
return str(candidate)
175175
raise
176176

177+
177178
def get_annotation_name(annotation):
178179
origin = get_origin(annotation)
179180
args = get_args(annotation)
@@ -193,6 +194,7 @@ def get_annotation_name(annotation):
193194
args_str = ", ".join(get_annotation_name(a) for a in args)
194195
return f"{get_annotation_name(origin)}[{args_str}]"
195196

197+
196198
def get_field_description_string(fields: dict) -> str:
197199
field_descriptions = []
198200
for idx, (k, v) in enumerate(fields.items()):
@@ -220,7 +222,7 @@ def _format_input_list_field_value(value: List[Any]) -> str:
220222
if len(value) == 1:
221223
return _format_blob(value[0])
222224

223-
return "\n".join([f"[{idx+1}] {_format_blob(txt)}" for idx, txt in enumerate(value)])
225+
return "\n".join([f"[{idx + 1}] {_format_blob(txt)}" for idx, txt in enumerate(value)])
224226

225227

226228
def _format_blob(blob: str) -> str:

0 commit comments

Comments
 (0)