-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsubtitle_extract_simple.py
More file actions
388 lines (315 loc) · 13.9 KB
/
subtitle_extract_simple.py
File metadata and controls
388 lines (315 loc) · 13.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
from tqdm import tqdm
from Levenshtein import ratio
from paddleocr import PaddleOCR
import cv2
import numpy as np
import os
import json
import time
from collections import Counter
def most_frequent_str(lst):
    """Return the string that occurs most often in ``lst``.

    When several strings share the highest count, the one whose first
    occurrence comes latest in ``lst`` wins (same tie-break as before).

    Args:
        lst: Iterable of hashable items (strings in practice).

    Returns:
        The winning string, or ``None`` when ``lst`` is empty.
    """
    counts = Counter(lst)
    if not counts:
        # Previously an empty input raised IndexError on most_common()[0].
        return None

    max_count = max(counts.values())
    winner = None
    # Counter keeps first-seen insertion order, so the last key that reaches
    # max_count is the latest first-seen string among the tied candidates.
    for candidate, count in counts.items():
        if count == max_count:
            winner = candidate
    return winner
class PaddleOCRSubtitleExtractor:
    """Extract hard-coded subtitles from a video with PaddleOCR.

    Pipeline (see ``extract``): sample frames from the video, crop the
    subtitle region, run OCR on each sampled crop, cache the per-frame OCR
    results as JSON, merge consecutive similar texts into subtitle entries,
    and write the result as SRT or JSON.
    """

    def __init__(
        self,
        frame_nums_per_second=10,
        similarity_threshold=0.5,  # base text-similarity threshold
        min_duration=0.2,  # minimum subtitle duration in seconds
        cache_dir="ocr_cache",  # directory for cached per-frame OCR results
        *args,
        **kwargs,
    ):
        """Store the configuration, create the cache dir and load the OCR model.

        Args:
            frame_nums_per_second: Target number of frames to OCR per second
                of video.
            similarity_threshold: Base Levenshtein-ratio threshold used by
                ``get_dynamic_threshold``.
            min_duration: Subtitles shorter than this (seconds) are dropped
                by ``merge_subtitles``.
            cache_dir: Directory where per-frame OCR results are cached.
            *args, **kwargs: Accepted but unused.
        """
        self.THRESHOLD_TEXT_SIMILARITY = similarity_threshold
        self.MIN_DURATION = min_duration
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)
        self.frame_nums_per_second = frame_nums_per_second
        self.init_ocr_model()

    def img_scale_rate(
        self,
        width,
        height,
        sub_area,
        max_width=1440,
        max_height=1080,
    ):
        """Compute a down-scaled frame size and the matching subtitle area.

        The frame is only ever shrunk: a single factor (the smaller of the
        per-axis factors) is applied to both axes so the aspect ratio is
        kept.

        Args:
            width: Original frame width in pixels.
            height: Original frame height in pixels.
            sub_area: ``(y_start, y_end, x_start, x_end)`` in original pixels.
            max_width: Largest allowed width after scaling.
            max_height: Largest allowed height after scaling.

        Returns:
            ``(new_width, new_height, scaled_sub_area)``; all values are
            truncated to ``int``.
        """
        scale_x = max_width / width if width > max_width else 1
        scale_y = max_height / height if height > max_height else 1
        # Use the smaller factor on both axes to preserve the aspect ratio.
        scale = min(scale_x, scale_y)

        new_width = int(width * scale)
        new_height = int(height * scale)
        scaled_sub_area = (
            int(sub_area[0] * scale),  # y_start
            int(sub_area[1] * scale),  # y_end
            int(sub_area[2] * scale),  # x_start
            int(sub_area[3] * scale),  # x_end
        )
        return new_width, new_height, scaled_sub_area

    def init_ocr_model(
        self,
        *args,
        **kwargs,
    ):
        """Create the PaddleOCR pipeline and store it on ``self.ocr``.

        Uses the PP-OCRv5 server detection/recognition models with document
        orientation classification, document unwarping and text-line
        orientation classification switched off. ``*args``/``**kwargs`` are
        accepted but unused.

        Other constructor options noted by the original author: ``lang="en"``
        for the English model, ``ocr_version="PP-OCRv4"`` for another PP-OCR
        version, ``device="gpu"`` (or e.g. ``"gpu:7"``) for GPU inference.
        """
        self.ocr = PaddleOCR(
            text_detection_model_name="PP-OCRv5_server_det",
            text_recognition_model_name="PP-OCRv5_server_rec",
            use_doc_orientation_classify=False,  # no document orientation model
            use_doc_unwarping=False,  # no text-image unwarping model
            use_textline_orientation=False,  # no text-line orientation model
            text_det_unclip_ratio=1.2,
            text_det_limit_side_len=1080,
            text_det_limit_type='max',
            text_rec_score_thresh=0.5,
            # device="gpu:7"
        )

    def ocr_predict(
        self,
        frame,
    ):
        """Run OCR on one image and return its boxes and texts.

        Args:
            frame: Image passed straight to ``self.ocr.predict`` (the caller
                in this file passes a cropped BGR frame from OpenCV).

        Returns:
            ``(dt_box, rec_res)``: the first result's ``'dt_polys'`` converted
            to nested Python lists (JSON-serialisable), and its
            ``'rec_texts'``.
        """
        ocr_res = self.ocr.predict(frame)
        first_result = ocr_res[0]
        dt_box = np.array(first_result['dt_polys']).tolist()
        rec_res = first_result['rec_texts']
        return dt_box, rec_res

    def get_dynamic_threshold(self, text1, text2):
        """Return the similarity threshold for comparing ``text1``/``text2``.

        The threshold depends on the longer of the two texts:

        * up to 3 characters: ``max(0.95, base + 0.1)`` - short texts must
          match almost exactly;
        * up to 10 characters: the base threshold;
        * longer: ``max(0.7, base - 0.05)``.

        NOTE(review): the original comment says long texts may use a *lower*
        threshold, but with a base below 0.75 the ``max(0.7, ...)`` floor
        makes it higher than the medium-length case - confirm the intent.
        """
        base_threshold = self.THRESHOLD_TEXT_SIMILARITY
        max_len = max(len(text1), len(text2))
        if max_len <= 3:
            return max(0.95, base_threshold + 0.1)
        if max_len <= 10:
            return base_threshold
        return max(0.7, base_threshold - 0.05)

    def merge_subtitles(self, frame_results):
        """Merge per-frame OCR texts into subtitle entries.

        Steps:
            1. Walk the frames in order, grouping consecutive frames whose
               text is similar (Levenshtein ``ratio`` >= dynamic threshold).
               The group's text is the most frequent text seen in the group.
               Frames with blank text are skipped; they do not close the
               current group.
            2. Drop groups shorter than ``self.MIN_DURATION``.
            3. Merge neighbouring groups that are similar and less than one
               second apart, keeping the longer text.

        Args:
            frame_results: Sequence of ``(timestamp, text_list, dt_box)``.

        Returns:
            List of dicts with keys ``'start'``, ``'end'``, ``'text'`` and
            ``'duration'`` (seconds).
        """
        if not frame_results:
            return []

        # Step 1: group consecutive similar frames.
        subtitles = []
        first_timestamp = frame_results[0][0]
        first_text = "".join(frame_results[0][1])
        current = {
            'start': first_timestamp,
            'end': first_timestamp,
            'text': first_text,
            'duration': 0,
        }
        # Every text observed for the current group, *including* the one that
        # opened it, so the first frame also takes part in the majority vote.
        observed_texts = [first_text]

        for timestamp, text_list, _dt_box in frame_results[1:]:
            text = "".join(text_list)
            if not text.strip():
                continue

            dyn_threshold = self.get_dynamic_threshold(current['text'], text)
            similarity = ratio(current['text'], text)

            if similarity >= dyn_threshold:
                # Same subtitle: extend it and re-elect the most common text.
                current['end'] = timestamp
                current['duration'] = current['end'] - current['start']
                observed_texts.append(text)
                current['text'] = most_frequent_str(observed_texts)
            else:
                # Different subtitle: close the current one, open a new one.
                subtitles.append(current)
                current = {
                    'start': timestamp,
                    'end': timestamp,
                    'text': text,
                    'duration': 0,
                }
                observed_texts = [text]
        subtitles.append(current)

        # Step 2: drop subtitles that are too short.
        filtered_subs = [
            sub for sub in subtitles if sub['duration'] >= self.MIN_DURATION
        ]
        if not filtered_subs:
            return []

        # Step 3: merge neighbours that are similar and close in time.
        merged = []
        current_sub = filtered_subs[0]
        for next_sub in filtered_subs[1:]:
            time_gap = next_sub['start'] - current_sub['end']
            dyn_threshold = self.get_dynamic_threshold(
                current_sub['text'], next_sub['text']
            )
            similarity = ratio(current_sub['text'], next_sub['text'])

            if time_gap < 1.0 and similarity >= dyn_threshold:
                current_sub['end'] = next_sub['end']
                current_sub['duration'] = current_sub['end'] - current_sub['start']
                # Keep the longer of the two texts.
                if len(next_sub['text']) > len(current_sub['text']):
                    current_sub['text'] = next_sub['text']
            else:
                merged.append(current_sub)
                current_sub = next_sub
        merged.append(current_sub)
        return merged

    def get_cache_filename(self, video_path, sub_area):
        """Build the cache file path for a video / subtitle-area pair.

        The name is made of the video's base name, the four ``sub_area``
        values, ``frame_nums_per_second`` and the similarity threshold, and
        lives in ``self.cache_dir``.
        """
        area_str = f"{sub_area[0]}-{sub_area[1]}-{sub_area[2]}-{sub_area[3]}"
        params_str = f"{self.frame_nums_per_second}-{self.THRESHOLD_TEXT_SIMILARITY}"
        base_name = os.path.splitext(os.path.basename(video_path))[0]
        return os.path.join(
            self.cache_dir, f"{base_name}_{area_str}_{params_str}.json"
        )

    def save_frame_results(self, frame_results, cache_file):
        """Write per-frame OCR results to ``cache_file`` as UTF-8 JSON.

        Each ``(timestamp, text_list, dt_box)`` tuple is stored as an object
        with keys ``"time"``, ``"text_list"`` and ``"dt_box"``.
        """
        simplified = [
            {"time": t, "text_list": text_list, "dt_box": dt_box}
            for t, text_list, dt_box in frame_results
        ]
        with open(cache_file, 'w', encoding='utf-8') as f:
            json.dump(simplified, f, indent=2, ensure_ascii=False)
        print(f"Saved frame results to cache: {cache_file}")

    def load_frame_results(self, cache_file):
        """Load per-frame OCR results written by ``save_frame_results``.

        Returns:
            A list of ``(time, text_list, dt_box)`` tuples, or ``None`` when
            the file does not exist or cannot be parsed (for example a cache
            file truncated by an interrupted run), so the caller falls back
            to processing the video again.
        """
        if not os.path.exists(cache_file):
            return None
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return [
                (item["time"], item["text_list"], item['dt_box'])
                for item in data
            ]
        except (ValueError, KeyError, TypeError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Ignoring unreadable cache file {cache_file}: {err}")
            return None

    def _ocr_video_frames(self, video_path, sub_area):
        """OCR the sampled frames of a video; return (frame_results, frames_read)."""
        video_cap = cv2.VideoCapture(video_path)
        if not video_cap.isOpened():
            raise IOError(f"Could not open video: {video_path}")

        pbar = None
        try:
            fps = video_cap.get(cv2.CAP_PROP_FPS)
            if fps <= 0:
                # Timestamps are frame_no / fps, so a zero FPS cannot be used.
                raise IOError(f"Could not determine FPS of video: {video_path}")
            frame_count = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
            width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # Only every `step`-th frame is sent to OCR.
            step = max(1, int(fps / self.frame_nums_per_second))

            # Shrink large frames; the subtitle area is scaled to match.
            width, height, sub_area = self.img_scale_rate(width, height, sub_area)

            print(f"Processing video: {os.path.basename(video_path)}")
            print(f"FPS: {fps:.2f}, Frames: {frame_count}, Step: {step}")
            print(f"Subtitle area: {sub_area}")

            frame_results = []  # (timestamp, text_list, dt_box) per OCR'd frame
            frames_read = 0
            pbar = tqdm(total=frame_count, desc="Extracting subtitles", ncols=100)

            while video_cap.isOpened():
                ret, frame = video_cap.read()
                if not ret:
                    break

                if frames_read % step == 0:
                    timestamp = frames_read / fps
                    # Resize first, then crop with the scaled subtitle area.
                    frame = cv2.resize(frame, (width, height))
                    sub_frame = frame[
                        sub_area[0]:sub_area[1], sub_area[2]:sub_area[3]
                    ]
                    dt_box, text_list = self.ocr_predict(sub_frame)
                    frame_results.append((timestamp, text_list, dt_box))

                frames_read += 1
                pbar.update(1)
        finally:
            # Release the capture and progress bar even if OCR fails midway.
            video_cap.release()
            if pbar is not None:
                pbar.close()

        return frame_results, frames_read

    def extract(
        self,
        video_path,
        sub_area,
        output_format='srt',  # 'srt' or 'json'
        output_dir=None
    ):
        """Extract subtitles from ``video_path`` and write them to disk.

        Per-frame OCR results are loaded from the cache when a non-empty
        cache file exists; otherwise the video is processed and the results
        are cached.

        Args:
            video_path: Path of the video file.
            sub_area: Subtitle region ``(y_start, y_end, x_start, x_end)`` in
                original-frame pixels.
            output_format: ``'json'`` (case-insensitive) writes
                ``<name>_subtitles.json``; anything else writes ``<name>.srt``.
            output_dir: Output directory; defaults to the video's directory.
                It is created if it does not exist.

        Returns:
            The merged subtitle list produced by ``merge_subtitles``.

        Raises:
            FileNotFoundError: ``video_path`` does not exist.
            IOError: The video cannot be opened or reports no usable FPS.
        """
        start_time = time.time()

        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video file not found: {video_path}")

        cache_file = self.get_cache_filename(video_path, sub_area)

        frame_results = self.load_frame_results(cache_file)
        if frame_results:
            print(f"Loaded frame results from cache: {cache_file}")
            current_frame_no = len(frame_results)
        else:
            # No (usable) cache: run OCR over the video and cache the result.
            frame_results, current_frame_no = self._ocr_video_frames(
                video_path, sub_area
            )
            self.save_frame_results(frame_results, cache_file)

        merged_subs = self.merge_subtitles(frame_results)

        if not output_dir:
            output_dir = os.path.dirname(video_path)
        if output_dir:
            # A caller-supplied directory may not exist yet.
            os.makedirs(output_dir, exist_ok=True)
        base_name = os.path.splitext(os.path.basename(video_path))[0]

        if output_format.lower() == 'json':
            output_path = os.path.join(output_dir, f"{base_name}_subtitles.json")
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(merged_subs, f, indent=2, ensure_ascii=False)
            print(f"Subtitles saved as JSON: {output_path}")
        else:  # SRT is the default format
            output_path = os.path.join(output_dir, f"{base_name}.srt")
            self.save_as_srt(merged_subs, output_path)
            print(f"Subtitles saved as SRT: {output_path}")

        elapsed = time.time() - start_time
        print(f"Processed {current_frame_no} frames in {elapsed:.2f} seconds")
        print(f"Extracted {len(merged_subs)} subtitles")

        return merged_subs

    def save_as_srt(self, subtitles, output_path):
        """Write ``subtitles`` to ``output_path`` in SRT format (UTF-8).

        Entries are numbered from 1; each needs ``'start'``, ``'end'``
        (seconds) and ``'text'``.
        """
        with open(output_path, 'w', encoding='utf-8') as f:
            for index, sub in enumerate(subtitles, 1):
                start_time = self.format_timestamp(sub['start'])
                end_time = self.format_timestamp(sub['end'])
                f.write(f"{index}\n")
                f.write(f"{start_time} --> {end_time}\n")
                f.write(f"{sub['text']}\n\n")

    def format_timestamp(self, seconds):
        """Format a number of seconds as an SRT timestamp ``HH:MM:SS,mmm``.

        The value is rounded to the nearest millisecond using integer
        arithmetic; truncating the float fraction turned e.g. 2.3 s into
        ``00:00:02,299``. Negative inputs are clamped to zero.
        """
        total_ms = max(0, int(round(seconds * 1000)))
        hours, remainder = divmod(total_ms, 3600 * 1000)
        minutes, remainder = divmod(remainder, 60 * 1000)
        secs, milliseconds = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"
if __name__ == '__main__':
    # Demo run with default settings: OCR the given region of the sample
    # video and write the subtitles (SRT by default) next to the video.
    demo_video_path = 'test_cn.mp4'
    # Region order is (y_start, y_end, x_start, x_end) in original pixels.
    demo_sub_area = (842, 1069, 72, 1368)

    extractor = PaddleOCRSubtitleExtractor()
    extractor.extract(video_path=demo_video_path, sub_area=demo_sub_area)