Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions demo/demo_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pathlib import Path
from PIL import Image
import requests
import argparse
import shutil # Import shutil for cleanup

# Local tool imports
Expand Down Expand Up @@ -348,7 +349,10 @@ def process_image_inference(session_state, test_image_input, file_input,
})

total_elements = len(pdf_result['combined_cells_data'])
info_text = f"**PDF Information:**\n- Total Pages: {pdf_result['total_pages']}\n- Server: {current_config['ip']}:{current_config['port_vllm']}\n- Total Detected Elements: {total_elements}\n- Session ID: {pdf_result['session_id']}"
is_layout_prompt = prompt_mode in ["prompt_layout_all_en", "prompt_layout_only_en", "prompt_grounding_ocr"]
info_text = f"**PDF Information:**\n- Total Pages: {pdf_result['total_pages']}\n- Server: {current_config['ip']}:{current_config['port_vllm']}\n- Session ID: {pdf_result['session_id']}"
if is_layout_prompt:
info_text = info_text.replace("\n- Session ID:", f"\n- Total Detected Elements: {total_elements}\n- Session ID:")

current_page_layout_image = preview_image
current_page_json = ""
Expand Down Expand Up @@ -400,8 +404,10 @@ def process_image_inference(session_state, test_image_input, file_input,
'result_paths': parse_result['result_paths']
})

num_elements = len(parse_result['cells_data']) if parse_result['cells_data'] else 0
info_text = f"**Image Information:**\n- Original Size: {original_image.width} x {original_image.height}\n- Model Input Size: {parse_result['input_width']} x {parse_result['input_height']}\n- Server: {current_config['ip']}:{current_config['port_vllm']}\n- Detected {num_elements} layout elements\n- Session ID: {parse_result['session_id']}"
num_elements = len(parse_result['cells_data']) if parse_result['cells_data'] else None
info_text = f"**Image Information:**\n- Original Size: {original_image.width} x {original_image.height}\n- Model Input Size: {parse_result['input_width']} x {parse_result['input_height']}\n- Server: {current_config['ip']}:{current_config['port_vllm']}\n- Session ID: {parse_result['session_id']}"
if num_elements is not None:
info_text = info_text.replace("\n- Session ID:", f"\n- Detected {num_elements} layout elements\n- Session ID:")

current_json = json.dumps(parse_result['cells_data'], ensure_ascii=False, indent=2) if parse_result['cells_data'] else ""

Expand All @@ -413,8 +419,11 @@ def process_image_inference(session_state, test_image_input, file_input,
for file in files:
if not file.endswith('.zip'): zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), parse_result['temp_dir']))

# Show layout image if created; otherwise fall back to original image (non-layout prompts or empty/invalid layout)
display_image = parse_result['layout_image'] if parse_result['layout_image'] is not None else original_image

return (
parse_result['layout_image'], info_text, parse_result['md_content'] or "No markdown content generated",
display_image, info_text, parse_result['md_content'] or "No markdown content generated",
md_content_raw, gr.update(value=download_zip_path, visible=bool(download_zip_path)),
None, current_json, session_state
)
Expand Down Expand Up @@ -716,11 +725,17 @@ def create_gradio_interface():

# ==================== Main Program ====================
if __name__ == "__main__":
import sys
port = int(sys.argv[1])
parser = argparse.ArgumentParser(description="Run dots.ocr Gradio demo")
# Port can be provided positionally for backward compatibility, or omitted to use default 7860
parser.add_argument("port", nargs="?", type=int, help="Port to run Gradio on (default: 7860)")
parser.add_argument("--share", action="store_true", help="Enable Gradio share URL (default: False)")
args = parser.parse_args()

port = args.port if args.port is not None else 7860
demo = create_gradio_interface()
demo.queue().launch(
server_name="0.0.0.0",
server_port=port,
debug=True
debug=True,
share=args.share
)
21 changes: 16 additions & 5 deletions demo/demo_gradio_annotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pathlib import Path
from PIL import Image
import requests
import argparse
from gradio_image_annotation import image_annotator

# Local utility imports
Expand Down Expand Up @@ -260,7 +261,8 @@ def process_image_inference_with_annotation(annotation_data, test_image_input,
if image is None:
return None, "Please select a test image or add an image in the annotation component", "", "", gr.update(value=None), ""
if bbox is None:
return "Please select a bounding box by mouse", "Please select a bounding box by mouse", "", "", gr.update(value=None)
# Ensure we return 6 outputs to match the interface bindings
return "Please select a bounding box by mouse", "Please select a bounding box by mouse", "", "", gr.update(value=None), ""

try:
# Process using DotsOCRParser, passing the bbox parameter
Expand Down Expand Up @@ -308,16 +310,18 @@ def process_image_inference_with_annotation(annotation_data, test_image_input,
)

# Handle the case where JSON parsing succeeds
num_elements = len(cells_data) if cells_data else 0
num_elements = len(cells_data) if cells_data else None
info_text = f"""
**Image Information:**
- Original Dimensions: {original_image.width} x {original_image.height}
- Processing Mode: {'Region OCR' if bbox else 'Full Image OCR'}
- Server: {current_config['ip']}:{current_config['port_vllm']}
- Detected {num_elements} layout elements
- Session ID: {parse_result['session_id']}
- Box Coordinates: {bbox if bbox else 'None'}
"""
# Only mention detected elements when cells_data is present (aligns with conditional layout behavior)
if num_elements is not None:
info_text = info_text.replace("\n- Session ID:", f"\n- Detected {num_elements} layout elements\n- Session ID:")

# Current page JSON output
current_json = ""
Expand Down Expand Up @@ -658,9 +662,16 @@ def create_gradio_interface():

# ==================== Main Program ====================
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run dots.ocr Gradio annotation demo")
parser.add_argument("port", nargs="?", type=int, help="Port to run Gradio on (default: 7861)")
parser.add_argument("--share", action="store_true", help="Enable Gradio share URL (default: False)")
args = parser.parse_args()

port = args.port if args.port is not None else 7861
demo = create_gradio_interface()
demo.queue().launch(
server_name="0.0.0.0",
server_port=7861, # Use a different port to avoid conflicts
debug=True
server_port=port, # Default different port to avoid conflicts
debug=True,
share=args.share
)
99 changes: 51 additions & 48 deletions demo/demo_streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,63 +109,66 @@ def get_image_input():



def process_and_display_results(output: str, image: Image.Image, config: dict):
"""Process and display inference results"""
def process_and_display_results(output: dict, image: Image.Image, config: dict):
"""Process and display inference results in a way that matches parser's refined logic"""
prompt, response = output['prompt'], output['response']

prompt_key = output.get('prompt_key')

# Determine if this is a layout prompt
layout_prompts = ['prompt_layout_all_en', 'prompt_layout_only_en', 'prompt_grounding_ocr']
is_layout = prompt_key in layout_prompts

# Always show original image
st.markdown('---')
st.markdown('##### Original Image')
st.image(image, width=image.width if image.width < 1000 else 1000)

# Show server input dimensions for reference
input_width, input_height = get_input_dimensions(
image,
min_pixels=config['min_pixels'],
max_pixels=config['max_pixels']
)
st.write(f'Input Dimensions: {input_width} x {input_height}')

if not is_layout:
# Non-layout prompts: show raw text output only
st.text_area('Model Output', response, height=300)
return

# Layout prompts: try to parse JSON and visualize only when there are cells
try:
col1, col2 = st.columns(2)
# st.markdown('---')
cells = json.loads(response)
# image = Image.open(img_url)

# Post-processing
cells = post_process_cells(
image, cells,
image.width, image.height,
min_pixels=config['min_pixels'],
max_pixels=config['max_pixels']
)

# Calculate input dimensions
input_width, input_height = get_input_dimensions(
image,
min_pixels=config['min_pixels'],
max_pixels=config['max_pixels']
)
st.markdown('---')
st.write(f'Input Dimensions: {input_width} x {input_height}')
# st.write(f'Prompt: {prompt}')
# st.markdown(f'模型原始输出: <span style="color:blue">{result}</span>', unsafe_allow_html=True)
# st.write('模型原始输出:')
# st.write(response)
# st.write('后处理结果:', str(cells))
st.text_area('Original Model Output', response, height=200)
st.text_area('Post-processed Result', str(cells), height=200)
# 显示结果
# st.title("Layout推理结果")

with col1:
# st.markdown("##### 可视化结果")
new_image = draw_layout_on_image(
image, cells,
resized_height=None, resized_width=None,
# text_key='text',
fill_bbox=True, draw_bbox=True
)
st.markdown('##### Visualization Result')
st.image(new_image, width=new_image.width)
# st.write(f"尺寸: {new_image.width} x {new_image.height}")

with col2:
# st.markdown("##### Markdown格式")
md_code = layoutjson2md(image, cells, text_key='text')
# md_code = fix_streamlit_formula(md_code)
st.markdown('##### Markdown Format')
st.markdown(md_code, unsafe_allow_html=True)


st.text_area('Original Model Output (JSON)', response, height=200)
st.text_area('Post-processed Cells', str(cells), height=200)

if isinstance(cells, list) and len(cells) > 0:
col1, col2 = st.columns(2)
with col1:
new_image = draw_layout_on_image(
image, cells,
resized_height=None, resized_width=None,
fill_bbox=True, draw_bbox=True
)
st.markdown('##### Layout Visualization')
st.image(new_image, width=new_image.width if new_image.width < 1000 else 1000)
with col2:
md_code = layoutjson2md(image, cells, text_key='text')
st.markdown('##### Markdown Format')
st.markdown(md_code, unsafe_allow_html=True)
else:
st.info('No layout detected. Skipping visualization.')
except json.JSONDecodeError:
st.error("Model output is not a valid JSON format")
# JSON invalid => align with parser: no layout image should be created
st.warning('Model output is not valid JSON for a layout prompt. Skipping visualization.')
st.text_area('Original Model Output', response, height=300)
except Exception as e:
st.error(f"Error processing results: {e}")

Expand Down Expand Up @@ -205,10 +208,10 @@ def main():

response = inference_with_vllm(
processed_image, prompt, config['ip'], config['port'],
# config['min_pixels'], config['max_pixels']
)
output = {
'prompt': prompt,
'prompt_key': config['prompt_key'],
'response': response,
}
else:
Expand Down
31 changes: 15 additions & 16 deletions dots_ocr/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ def _parse_single_image(
}
if source == 'pdf':
save_name = f"{save_name}_page_{page_idx}"
# Always save original (untouched) image
original_image_path = os.path.join(save_dir, f"{save_name}_original.jpg")
origin_image.save(original_image_path)
result.update({'original_image_path': original_image_path})
if prompt_mode in ['prompt_layout_all_en', 'prompt_layout_only_en', 'prompt_grounding_ocr']:
cells, filtered = post_process_output(
response,
Expand All @@ -185,12 +189,8 @@ def _parse_single_image(
json_file_path = os.path.join(save_dir, f"{save_name}.json")
with open(json_file_path, 'w', encoding="utf-8") as w:
json.dump(response, w, ensure_ascii=False)

image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
origin_image.save(image_layout_path)
result.update({
'layout_info_path': json_file_path,
'layout_image_path': image_layout_path,
})

md_file_path = os.path.join(save_dir, f"{save_name}.md")
Expand All @@ -204,21 +204,26 @@ def _parse_single_image(
})
else:
try:
image_with_layout = draw_layout_on_image(origin_image, cells)
if isinstance(cells, list) and len(cells) > 0:
image_with_layout = draw_layout_on_image(origin_image, cells)
else:
image_with_layout = None
except Exception as e:
print(f"Error drawing layout on image: {e}")
image_with_layout = origin_image
image_with_layout = None

json_file_path = os.path.join(save_dir, f"{save_name}.json")
with open(json_file_path, 'w', encoding="utf-8") as w:
json.dump(cells, w, ensure_ascii=False)

image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
image_with_layout.save(image_layout_path)
result.update({
'layout_info_path': json_file_path,
'layout_image_path': image_layout_path,
})
if image_with_layout is not None:
image_layout_path = os.path.join(save_dir, f"{save_name}_layout.jpg")
image_with_layout.save(image_layout_path)
result.update({
'layout_image_path': image_layout_path,
})
if prompt_mode != "prompt_layout_only_en": # no text md when detection only
md_content = layoutjson2md(origin_image, cells, text_key='text')
md_content_no_hf = layoutjson2md(origin_image, cells, text_key='text', no_page_hf=True) # used for clean output or metric of omnidocbench、olmbench
Expand All @@ -233,12 +238,6 @@ def _parse_single_image(
'md_content_nohf_path': md_nohf_file_path,
})
else:
image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
origin_image.save(image_layout_path)
result.update({
'layout_image_path': image_layout_path,
})

md_content = response
md_file_path = os.path.join(save_dir, f"{save_name}.md")
with open(md_file_path, "w", encoding="utf-8") as md_file:
Expand Down