[downloader/fragment] Improve --live-from-start for YouTube livestreams (#2870)

This commit is contained in:
Lesmiscore (Naoya Ozaki) 2022-02-25 02:00:46 +09:00 committed by GitHub
parent b440e1bb22
commit a539f06570
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 16 deletions

View File

@ -25,6 +25,7 @@
error_to_compat_str, error_to_compat_str,
encodeFilename, encodeFilename,
sanitized_Request, sanitized_Request,
traverse_obj,
) )
@ -382,6 +383,7 @@ def download_and_append_fragments_multiple(self, *args, pack_func=None, finish_f
max_workers = self.params.get('concurrent_fragment_downloads', 1) max_workers = self.params.get('concurrent_fragment_downloads', 1)
if max_progress > 1: if max_progress > 1:
self._prepare_multiline_status(max_progress) self._prepare_multiline_status(max_progress)
is_live = any(traverse_obj(args, (..., 2, 'is_live'), default=[]))
def thread_func(idx, ctx, fragments, info_dict, tpe): def thread_func(idx, ctx, fragments, info_dict, tpe):
ctx['max_progress'] = max_progress ctx['max_progress'] = max_progress
@ -395,25 +397,44 @@ class FTPE(concurrent.futures.ThreadPoolExecutor):
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
pass pass
spins = []
if compat_os_name == 'nt': if compat_os_name == 'nt':
self.report_warning('Ctrl+C does not work on Windows when used with parallel threads. ' def bindoj_result(future):
'This is a known issue and patches are welcome') while True:
try:
return future.result(0.1)
except KeyboardInterrupt:
raise
except concurrent.futures.TimeoutError:
continue
else:
def bindoj_result(future):
return future.result()
spins = []
for idx, (ctx, fragments, info_dict) in enumerate(args): for idx, (ctx, fragments, info_dict) in enumerate(args):
tpe = FTPE(math.ceil(max_workers / max_progress)) tpe = FTPE(math.ceil(max_workers / max_progress))
job = tpe.submit(thread_func, idx, ctx, fragments, info_dict, tpe)
def interrupt_trigger_iter():
for f in fragments:
if not interrupt_trigger[0]:
break
yield f
job = tpe.submit(thread_func, idx, ctx, interrupt_trigger_iter(), info_dict, tpe)
spins.append((tpe, job)) spins.append((tpe, job))
result = True result = True
for tpe, job in spins: for tpe, job in spins:
try: try:
result = result and job.result() result = result and bindoj_result(job)
except KeyboardInterrupt: except KeyboardInterrupt:
interrupt_trigger[0] = False interrupt_trigger[0] = False
finally: finally:
tpe.shutdown(wait=True) tpe.shutdown(wait=True)
if not interrupt_trigger[0]: if not interrupt_trigger[0] and not is_live:
raise KeyboardInterrupt() raise KeyboardInterrupt()
# we expect the user wants to stop and DO WANT the preceding postprocessors to run;
# so returning an intermediate result here instead of KeyboardInterrupt on live
return result return result
def download_and_append_fragments( def download_and_append_fragments(
@ -431,10 +452,11 @@ def download_and_append_fragments(
pack_func = lambda frag_content, _: frag_content pack_func = lambda frag_content, _: frag_content
def download_fragment(fragment, ctx): def download_fragment(fragment, ctx):
if not interrupt_trigger[0]:
return False, fragment['frag_index']
frag_index = ctx['fragment_index'] = fragment['frag_index'] frag_index = ctx['fragment_index'] = fragment['frag_index']
ctx['last_error'] = None ctx['last_error'] = None
if not interrupt_trigger[0]:
return False, frag_index
headers = info_dict.get('http_headers', {}).copy() headers = info_dict.get('http_headers', {}).copy()
byte_range = fragment.get('byte_range') byte_range = fragment.get('byte_range')
if byte_range: if byte_range:
@ -500,8 +522,6 @@ def _download_fragment(fragment):
self.report_warning('The download speed shown is only of one thread. This is a known issue and patches are welcome') self.report_warning('The download speed shown is only of one thread. This is a known issue and patches are welcome')
with tpe or concurrent.futures.ThreadPoolExecutor(max_workers) as pool: with tpe or concurrent.futures.ThreadPoolExecutor(max_workers) as pool:
for fragment, frag_content, frag_index, frag_filename in pool.map(_download_fragment, fragments): for fragment, frag_content, frag_index, frag_filename in pool.map(_download_fragment, fragments):
if not interrupt_trigger[0]:
break
ctx['fragment_filename_sanitized'] = frag_filename ctx['fragment_filename_sanitized'] = frag_filename
ctx['fragment_index'] = frag_index ctx['fragment_index'] = frag_index
result = append_fragment(decrypt_fragment(fragment, frag_content), frag_index, ctx) result = append_fragment(decrypt_fragment(fragment, frag_content), frag_index, ctx)

View File

@ -2135,6 +2135,7 @@ def mpd_feed(format_id, delay):
return f['manifest_url'], f['manifest_stream_number'], is_live return f['manifest_url'], f['manifest_stream_number'], is_live
for f in formats: for f in formats:
f['is_live'] = True
f['protocol'] = 'http_dash_segments_generator' f['protocol'] = 'http_dash_segments_generator'
f['fragments'] = functools.partial( f['fragments'] = functools.partial(
self._live_dash_fragments, f['format_id'], live_start_time, mpd_feed) self._live_dash_fragments, f['format_id'], live_start_time, mpd_feed)
@ -2157,12 +2158,12 @@ def _live_dash_fragments(self, format_id, live_start_time, mpd_feed, ctx):
known_idx, no_fragment_score, last_segment_url = begin_index, 0, None known_idx, no_fragment_score, last_segment_url = begin_index, 0, None
fragments, fragment_base_url = None, None fragments, fragment_base_url = None, None
def _extract_sequence_from_mpd(refresh_sequence): def _extract_sequence_from_mpd(refresh_sequence, immediate):
nonlocal mpd_url, stream_number, is_live, no_fragment_score, fragments, fragment_base_url nonlocal mpd_url, stream_number, is_live, no_fragment_score, fragments, fragment_base_url
# Obtain from MPD's maximum seq value # Obtain from MPD's maximum seq value
old_mpd_url = mpd_url old_mpd_url = mpd_url
last_error = ctx.pop('last_error', None) last_error = ctx.pop('last_error', None)
expire_fast = last_error and isinstance(last_error, compat_HTTPError) and last_error.code == 403 expire_fast = immediate or last_error and isinstance(last_error, compat_HTTPError) and last_error.code == 403
mpd_url, stream_number, is_live = (mpd_feed(format_id, 5 if expire_fast else 18000) mpd_url, stream_number, is_live = (mpd_feed(format_id, 5 if expire_fast else 18000)
or (mpd_url, stream_number, False)) or (mpd_url, stream_number, False))
if not refresh_sequence: if not refresh_sequence:
@ -2176,7 +2177,7 @@ def _extract_sequence_from_mpd(refresh_sequence):
except ExtractorError: except ExtractorError:
fmts = None fmts = None
if not fmts: if not fmts:
no_fragment_score += 1 no_fragment_score += 2
return False, last_seq return False, last_seq
fmt_info = next(x for x in fmts if x['manifest_stream_number'] == stream_number) fmt_info = next(x for x in fmts if x['manifest_stream_number'] == stream_number)
fragments = fmt_info['fragments'] fragments = fmt_info['fragments']
@ -2199,11 +2200,12 @@ def _extract_sequence_from_mpd(refresh_sequence):
urlh = None urlh = None
last_seq = try_get(urlh, lambda x: int_or_none(x.headers['X-Head-Seqnum'])) last_seq = try_get(urlh, lambda x: int_or_none(x.headers['X-Head-Seqnum']))
if last_seq is None: if last_seq is None:
no_fragment_score += 1 no_fragment_score += 2
last_segment_url = None last_segment_url = None
continue continue
else: else:
should_continue, last_seq = _extract_sequence_from_mpd(True) should_continue, last_seq = _extract_sequence_from_mpd(True, no_fragment_score > 15)
no_fragment_score += 2
if not should_continue: if not should_continue:
continue continue
@ -2221,7 +2223,7 @@ def _extract_sequence_from_mpd(refresh_sequence):
try: try:
for idx in range(known_idx, last_seq): for idx in range(known_idx, last_seq):
# do not update sequence here or you'll get skipped some part of it # do not update sequence here or you'll get skipped some part of it
should_continue, _ = _extract_sequence_from_mpd(False) should_continue, _ = _extract_sequence_from_mpd(False, False)
if not should_continue: if not should_continue:
known_idx = idx - 1 known_idx = idx - 1
raise ExtractorError('breaking out of outer loop') raise ExtractorError('breaking out of outer loop')