mirror of
https://github.com/yt-dlp/yt-dlp.git
synced 2024-11-18 18:59:32 +01:00
Fixup frame timestamp in MP4 file without ffmpeg
This commit is contained in:
parent
56ba69e4c9
commit
d4c52a28af
51
test/test_mp4parser.py
Normal file
51
test/test_mp4parser.py
Normal file
@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Allow direct execution
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import io
|
||||
|
||||
from yt_dlp.mp4_parser import (
|
||||
parse_mp4_boxes,
|
||||
write_mp4_boxes,
|
||||
)
|
||||
|
||||
TEST_SEQUENCE = [
|
||||
('test', b'123456'),
|
||||
('trak', b''),
|
||||
('helo', b'abcdef'),
|
||||
('1984', b'1q84'),
|
||||
('moov', b''),
|
||||
('keys', b'2022'),
|
||||
(None, 'moov'),
|
||||
('topp', b'1991'),
|
||||
(None, 'trak'),
|
||||
]
|
||||
|
||||
# on-file reprensetation of the above sequence
|
||||
TEST_BYTES = b'\x00\x00\x00\x0etest123456\x00\x00\x00Btrak\x00\x00\x00\x0eheloabcdef\x00\x00\x00\x0c19841q84\x00\x00\x00\x14moov\x00\x00\x00\x0ckeys2022\x00\x00\x00\x0ctopp1991'
|
||||
|
||||
|
||||
class TestMP4Parser(unittest.TestCase):
|
||||
def test_write_sequence(self):
|
||||
with io.BytesIO() as w:
|
||||
write_mp4_boxes(w, TEST_SEQUENCE)
|
||||
bs = w.getvalue()
|
||||
self.assertEqual(TEST_BYTES, bs)
|
||||
|
||||
def test_read_bytes(self):
|
||||
with io.BytesIO(TEST_BYTES) as r:
|
||||
result = list(parse_mp4_boxes(r))
|
||||
self.assertListEqual(TEST_SEQUENCE, result)
|
||||
|
||||
def test_mismatched_box_end(self):
|
||||
with io.BytesIO() as w, self.assertRaises(AssertionError):
|
||||
write_mp4_boxes(w, [
|
||||
('moov', b''),
|
||||
('trak', b''),
|
||||
(None, 'moov'),
|
||||
(None, 'trak'),
|
||||
])
|
@ -55,6 +55,7 @@
|
||||
FFmpegMergerPP,
|
||||
FFmpegPostProcessor,
|
||||
MoveFilesAfterDownloadPP,
|
||||
MP4FixupTimestampPP,
|
||||
get_postprocessor,
|
||||
)
|
||||
from .update import detect_variant
|
||||
@ -3256,8 +3257,11 @@ def ffmpeg_fixup(cndn, msg, cls):
|
||||
ffmpeg_fixup(info_dict.get('is_live') and downloader == 'DashSegmentsFD',
|
||||
'Possible duplicate MOOV atoms', FFmpegFixupDuplicateMoovPP)
|
||||
|
||||
is_fmp4 = info_dict.get('protocol') == 'websocket_frag' and info_dict.get('container') == 'fmp4'
|
||||
ffmpeg_fixup(downloader == 'web_socket_fragment', 'Malformed timestamps detected', FFmpegFixupTimestampPP)
|
||||
ffmpeg_fixup(downloader == 'web_socket_fragment', 'Malformed duration detected', FFmpegFixupDurationPP)
|
||||
ffmpeg_fixup(downloader == 'web_socket_to_file' and is_fmp4, 'Malformed timestamps detected', MP4FixupTimestampPP)
|
||||
ffmpeg_fixup(downloader == 'web_socket_to_file' and is_fmp4, 'Possible duplicate MOOV atoms', FFmpegFixupDuplicateMoovPP)
|
||||
|
||||
fixup()
|
||||
try:
|
||||
|
@ -33,7 +33,7 @@ def get_suitable_downloader(info_dict, params={}, default=NO_DEFAULT, protocol=N
|
||||
from .niconico import NiconicoDmcFD
|
||||
from .rtmp import RtmpFD
|
||||
from .rtsp import RtspFD
|
||||
from .websocket import WebSocketFragmentFD
|
||||
from .websocket import WebSocketFragmentFD, WebSocketToFileFD
|
||||
from .youtube_live_chat import YoutubeLiveChatFD
|
||||
|
||||
PROTOCOL_MAP = {
|
||||
@ -118,6 +118,9 @@ def _get_suitable_downloader(info_dict, protocol, params, default):
|
||||
elif params.get('hls_prefer_native') is False:
|
||||
return FFmpegFD
|
||||
|
||||
if protocol == 'websocket_frag' and info_dict.get('container') == 'fmp4' and external_downloader != 'ffmpeg':
|
||||
return WebSocketToFileFD
|
||||
|
||||
return PROTOCOL_MAP.get(protocol, default)
|
||||
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
import contextlib
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
|
||||
from .common import FileDownloader
|
||||
from .external import FFmpegFD
|
||||
@ -9,23 +8,29 @@
|
||||
from ..dependencies import websockets
|
||||
|
||||
|
||||
class FFmpegSinkFD(FileDownloader):
|
||||
class AsyncSinkFD(FileDownloader):
|
||||
async def connect(self, stdin, info_dict):
|
||||
try:
|
||||
await self.real_connection(stdin, info_dict)
|
||||
except OSError:
|
||||
pass
|
||||
finally:
|
||||
with contextlib.suppress(OSError):
|
||||
stdin.flush()
|
||||
stdin.close()
|
||||
|
||||
async def real_connection(self, sink, info_dict):
|
||||
""" Override this in subclasses """
|
||||
raise NotImplementedError('This method must be implemented by subclasses')
|
||||
|
||||
|
||||
class FFmpegSinkFD(AsyncSinkFD):
|
||||
""" A sink to ffmpeg for downloading fragments in any form """
|
||||
|
||||
def real_download(self, filename, info_dict):
|
||||
info_copy = info_dict.copy()
|
||||
info_copy['url'] = '-'
|
||||
|
||||
async def call_conn(proc, stdin):
|
||||
try:
|
||||
await self.real_connection(stdin, info_dict)
|
||||
except OSError:
|
||||
pass
|
||||
finally:
|
||||
with contextlib.suppress(OSError):
|
||||
stdin.flush()
|
||||
stdin.close()
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
connect = self.connect
|
||||
|
||||
class FFmpegStdinFD(FFmpegFD):
|
||||
@classmethod
|
||||
@ -33,17 +38,57 @@ def get_basename(cls):
|
||||
return FFmpegFD.get_basename()
|
||||
|
||||
def on_process_started(self, proc, stdin):
|
||||
thread = threading.Thread(target=asyncio.run, daemon=True, args=(call_conn(proc, stdin), ))
|
||||
thread = threading.Thread(target=asyncio.run, daemon=True, args=(connect(stdin, info_dict), ))
|
||||
thread.start()
|
||||
|
||||
return FFmpegStdinFD(self.ydl, self.params or {}).download(filename, info_copy)
|
||||
|
||||
async def real_connection(self, sink, info_dict):
|
||||
""" Override this in subclasses """
|
||||
raise NotImplementedError('This method must be implemented by subclasses')
|
||||
|
||||
class FileSinkFD(AsyncSinkFD):
|
||||
""" A sink to a file for downloading fragments in any form """
|
||||
def real_download(self, filename, info_dict):
|
||||
tempname = self.temp_name(filename)
|
||||
try:
|
||||
with open(tempname, 'wb') as w:
|
||||
started = time.time()
|
||||
status = {
|
||||
'filename': info_dict.get('_filename'),
|
||||
'status': 'downloading',
|
||||
'elapsed': 0,
|
||||
'downloaded_bytes': 0,
|
||||
}
|
||||
self._hook_progress(status, info_dict)
|
||||
|
||||
thread = threading.Thread(target=asyncio.run, daemon=True, args=(self.connect(w, info_dict), ))
|
||||
thread.start()
|
||||
time_and_size, avg_len = [], 10
|
||||
while thread.is_alive():
|
||||
time.sleep(0.1)
|
||||
|
||||
downloaded, curr = w.tell(), time.time()
|
||||
# taken from ffmpeg attachment
|
||||
time_and_size.append((downloaded, curr))
|
||||
time_and_size = time_and_size[-avg_len:]
|
||||
if len(time_and_size) > 1:
|
||||
last, early = time_and_size[0], time_and_size[-1]
|
||||
average_speed = (early[0] - last[0]) / (early[1] - last[1])
|
||||
else:
|
||||
average_speed = None
|
||||
|
||||
status.update({
|
||||
'downloaded_bytes': downloaded,
|
||||
'speed': average_speed,
|
||||
'elapsed': curr - started,
|
||||
})
|
||||
self._hook_progress(status, info_dict)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
self.ydl.replace(tempname, filename)
|
||||
return True
|
||||
|
||||
|
||||
class WebSocketFragmentFD(FFmpegSinkFD):
|
||||
class _WebSocketFD(AsyncSinkFD):
|
||||
async def real_connection(self, sink, info_dict):
|
||||
async with websockets.connect(info_dict['url'], extra_headers=info_dict.get('http_headers', {})) as ws:
|
||||
while True:
|
||||
@ -51,3 +96,11 @@ async def real_connection(self, sink, info_dict):
|
||||
if isinstance(recv, str):
|
||||
recv = recv.encode('utf8')
|
||||
sink.write(recv)
|
||||
|
||||
|
||||
class WebSocketFragmentFD(_WebSocketFD, FFmpegSinkFD):
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketToFileFD(_WebSocketFD, FileSinkFD):
|
||||
pass
|
||||
|
@ -173,6 +173,7 @@ def find_dmu(x):
|
||||
'source_preference': -10,
|
||||
# TwitCasting simply sends moof atom directly over WS
|
||||
'protocol': 'websocket_frag',
|
||||
'container': 'fmp4',
|
||||
})
|
||||
|
||||
self._sort_formats(formats, ('source',))
|
||||
|
136
yt_dlp/mp4_parser.py
Normal file
136
yt_dlp/mp4_parser.py
Normal file
@ -0,0 +1,136 @@
|
||||
import struct
|
||||
|
||||
from typing import Tuple
|
||||
from io import BytesIO, RawIOBase
|
||||
|
||||
|
||||
class LengthLimiter(RawIOBase):
|
||||
def __init__(self, r: RawIOBase, size: int):
|
||||
self.r = r
|
||||
self.remaining = size
|
||||
|
||||
def read(self, sz: int = None) -> bytes:
|
||||
if self.remaining == 0:
|
||||
return b''
|
||||
if sz in (-1, None):
|
||||
sz = self.remaining
|
||||
sz = min(sz, self.remaining)
|
||||
ret = self.r.read(sz)
|
||||
if ret:
|
||||
self.remaining -= len(ret)
|
||||
return ret
|
||||
|
||||
def readall(self) -> bytes:
|
||||
if self.remaining == 0:
|
||||
return b''
|
||||
ret = self.read(self.remaining)
|
||||
if ret:
|
||||
self.remaining -= len(ret)
|
||||
return ret
|
||||
|
||||
def readable(self) -> bool:
|
||||
return bool(self.remaining)
|
||||
|
||||
|
||||
def read_harder(r, size):
|
||||
retry = 0
|
||||
buf = b''
|
||||
while len(buf) < size and retry < 3:
|
||||
ret = r.read(size - len(buf))
|
||||
if not ret:
|
||||
retry += 1
|
||||
continue
|
||||
retry = 0
|
||||
buf += ret
|
||||
|
||||
return buf
|
||||
|
||||
|
||||
def pack_be32(value: int) -> bytes:
|
||||
return struct.pack('>I', value)
|
||||
|
||||
|
||||
def pack_be64(value: int) -> bytes:
|
||||
return struct.pack('>L', value)
|
||||
|
||||
|
||||
def unpack_be32(value: bytes) -> int:
|
||||
return struct.unpack('>I', value)[0]
|
||||
|
||||
|
||||
def unpack_ver_flags(value: bytes) -> Tuple[int, int]:
|
||||
ver, up_flag, down_flag = struct.unpack('>BBH', value)
|
||||
return ver, (up_flag << 16 | down_flag)
|
||||
|
||||
|
||||
def unpack_be64(value: bytes) -> int:
|
||||
return struct.unpack('>L', value)[0]
|
||||
|
||||
|
||||
# https://github.com/gpac/mp4box.js/blob/4e1bc23724d2603754971abc00c2bd5aede7be60/src/box.js#L13-L40
|
||||
MP4_CONTAINER_BOXES = ('moov', 'trak', 'edts', 'mdia', 'minf', 'dinf', 'stbl', 'mvex', 'moof', 'traf', 'vttc', 'tref', 'iref', 'mfra', 'meco', 'hnti', 'hinf', 'strk', 'strd', 'sinf', 'rinf', 'schi', 'trgr', 'udta', 'iprp', 'ipco')
|
||||
|
||||
|
||||
def parse_mp4_boxes(r: RawIOBase):
|
||||
"""
|
||||
Parses an ISO BMFF (which MP4 follows) and yields its boxes as a sequence.
|
||||
This does not interpret content of these boxes.
|
||||
|
||||
Sequence details:
|
||||
('atom', b'blablabla'): A box, with content (not container boxes)
|
||||
('atom', b''): Possibly container box (must check MP4_CONTAINER_BOXES) or really an empty box
|
||||
(None, 'atom'): End of a container box
|
||||
|
||||
Example: Path:
|
||||
('test', b'123456') /test
|
||||
('box1', b'') /box1 (start of container box)
|
||||
('helo', b'abcdef') /box1/helo
|
||||
('1984', b'1q84') /box1/1984
|
||||
('http', b'') /box1/http (start of container box)
|
||||
('keys', b'2022') /box1/http/keys
|
||||
(None , 'http') /box1/http (end of container box)
|
||||
('topp', b'1991') /box1/topp
|
||||
(None , 'box1') /box1 (end of container box)
|
||||
"""
|
||||
|
||||
while True:
|
||||
size_b = read_harder(r, 4)
|
||||
if not size_b:
|
||||
break
|
||||
type_b = r.read(4)
|
||||
# 00 00 00 20 is big-endian
|
||||
box_size = unpack_be32(size_b)
|
||||
type_s = type_b.decode()
|
||||
if type_s in MP4_CONTAINER_BOXES:
|
||||
yield (type_s, b'')
|
||||
yield from parse_mp4_boxes(LengthLimiter(r, box_size - 8))
|
||||
yield (None, type_s)
|
||||
continue
|
||||
# subtract by 8
|
||||
full_body = read_harder(r, box_size - 8)
|
||||
yield (type_s, full_body)
|
||||
|
||||
|
||||
def write_mp4_boxes(w: RawIOBase, box_iter):
|
||||
"""
|
||||
Writes an ISO BMFF file from a given sequence to a given writer.
|
||||
The iterator to be passed must follow parse_mp4_boxes's protocol.
|
||||
"""
|
||||
|
||||
stack = [
|
||||
(None, w), # parent box, IO
|
||||
]
|
||||
for btype, content in box_iter:
|
||||
if btype in MP4_CONTAINER_BOXES:
|
||||
bio = BytesIO()
|
||||
stack.append((btype, bio))
|
||||
continue
|
||||
elif btype is None:
|
||||
assert stack[-1][0] == content
|
||||
btype, bio = stack.pop()
|
||||
content = bio.getvalue()
|
||||
|
||||
wt = stack[-1][1]
|
||||
wt.write(pack_be32(len(content) + 8))
|
||||
wt.write(btype.encode()[:4])
|
||||
wt.write(content)
|
@ -30,6 +30,7 @@
|
||||
)
|
||||
from .modify_chapters import ModifyChaptersPP
|
||||
from .movefilesafterdownload import MoveFilesAfterDownloadPP
|
||||
from .mp4direct import MP4FixupTimestampPP
|
||||
from .sponskrub import SponSkrubPP
|
||||
from .sponsorblock import SponsorBlockPP
|
||||
from .xattrpp import XAttrMetadataPP
|
||||
|
126
yt_dlp/postprocessor/mp4direct.py
Normal file
126
yt_dlp/postprocessor/mp4direct.py
Normal file
@ -0,0 +1,126 @@
|
||||
from .common import PostProcessor
|
||||
from ..utils import prepend_extension
|
||||
|
||||
from ..mp4_parser import (
|
||||
write_mp4_boxes,
|
||||
parse_mp4_boxes,
|
||||
pack_be32,
|
||||
pack_be64,
|
||||
unpack_ver_flags,
|
||||
unpack_be32,
|
||||
unpack_be64,
|
||||
)
|
||||
|
||||
|
||||
class MP4FixupTimestampPP(PostProcessor):
|
||||
|
||||
@property
|
||||
def available(self):
|
||||
return True
|
||||
|
||||
def analyze_mp4(self, filepath):
|
||||
""" returns (baseMediaDecodeTime offset, sample duration cutoff) """
|
||||
smallest_bmdt, known_sdur = float('inf'), set()
|
||||
with open(filepath, 'rb') as r:
|
||||
for btype, content in parse_mp4_boxes(r):
|
||||
if btype == 'tfdt':
|
||||
version, _ = unpack_ver_flags(content[0:4])
|
||||
# baseMediaDecodeTime always comes to the first
|
||||
if version == 0:
|
||||
bmdt = unpack_be32(content[4:8])
|
||||
else:
|
||||
bmdt = unpack_be64(content[4:12])
|
||||
if bmdt == 0:
|
||||
continue
|
||||
smallest_bmdt = min(bmdt, smallest_bmdt)
|
||||
elif btype == 'tfhd':
|
||||
version, flags = unpack_ver_flags(content[0:4])
|
||||
if not flags & 0x08:
|
||||
# this box does not contain "sample duration"
|
||||
continue
|
||||
# https://github.com/gpac/mp4box.js/blob/4e1bc23724d2603754971abc00c2bd5aede7be60/src/box.js#L203-L209
|
||||
# https://github.com/gpac/mp4box.js/blob/4e1bc23724d2603754971abc00c2bd5aede7be60/src/parsing/tfhd.js
|
||||
sdur_start = 8 # header + track id
|
||||
if flags & 0x01:
|
||||
sdur_start += 8
|
||||
if flags & 0x02:
|
||||
sdur_start += 4
|
||||
# the next 4 bytes are "sample duration"
|
||||
sample_dur = unpack_be32(content[sdur_start:sdur_start + 4])
|
||||
known_sdur.add(sample_dur)
|
||||
|
||||
maximum_sdur = max(known_sdur)
|
||||
for multiplier in (0.7, 0.8, 0.9, 0.95):
|
||||
sdur_cutoff = maximum_sdur * multiplier
|
||||
if len(set(x for x in known_sdur if x > sdur_cutoff)) < 3:
|
||||
break
|
||||
else:
|
||||
sdur_cutoff = float('inf')
|
||||
|
||||
return smallest_bmdt, sdur_cutoff
|
||||
|
||||
def modify_mp4(self, src, dst, bmdt_offset, sdur_cutoff):
|
||||
with open(src, 'rb') as r, open(dst, 'wb') as w:
|
||||
def converter():
|
||||
for btype, content in parse_mp4_boxes(r):
|
||||
if btype == 'tfdt':
|
||||
version, _ = unpack_ver_flags(content[0:4])
|
||||
if version == 0:
|
||||
bmdt = unpack_be32(content[4:8])
|
||||
else:
|
||||
bmdt = unpack_be64(content[4:12])
|
||||
if bmdt == 0:
|
||||
yield (btype, content)
|
||||
continue
|
||||
# calculate new baseMediaDecodeTime
|
||||
bmdt = max(0, bmdt - bmdt_offset)
|
||||
# pack everything again and insert as a new box
|
||||
if version == 0:
|
||||
bmdt_b = pack_be32(bmdt)
|
||||
else:
|
||||
bmdt_b = pack_be64(bmdt)
|
||||
yield ('tfdt', content[0:4] + bmdt_b + content[8 + version * 4:])
|
||||
continue
|
||||
elif btype == 'tfhd':
|
||||
version, flags = unpack_ver_flags(content[0:4])
|
||||
if not flags & 0x08:
|
||||
yield (btype, content)
|
||||
continue
|
||||
sdur_start = 8
|
||||
if flags & 0x01:
|
||||
sdur_start += 8
|
||||
if flags & 0x02:
|
||||
sdur_start += 4
|
||||
sample_dur = unpack_be32(content[sdur_start:sdur_start + 4])
|
||||
if sample_dur > sdur_cutoff:
|
||||
sample_dur = 0
|
||||
sd_b = pack_be32(sample_dur)
|
||||
yield ('tfhd', content[:sdur_start] + sd_b + content[sdur_start + 4:])
|
||||
continue
|
||||
yield (btype, content)
|
||||
|
||||
write_mp4_boxes(w, converter())
|
||||
|
||||
def run(self, information):
|
||||
filename = information['filepath']
|
||||
temp_filename = prepend_extension(filename, 'temp')
|
||||
|
||||
self.write_debug('Analyzing MP4')
|
||||
bmdt_offset, sdur_cutoff = self.analyze_mp4(filename)
|
||||
working = float('inf') not in (bmdt_offset, sdur_cutoff)
|
||||
# if any of them are Infinity, there's something wrong
|
||||
# baseMediaDecodeTime = to shift PTS
|
||||
# sample duration = to define duration in each segment
|
||||
self.write_debug(f'baseMediaDecodeTime offset = {bmdt_offset}, sample duration cutoff = {sdur_cutoff}')
|
||||
if bmdt_offset == float('inf'):
|
||||
# safeguard
|
||||
bmdt_offset = 0
|
||||
self.modify_mp4(filename, temp_filename, bmdt_offset, sdur_cutoff)
|
||||
if working:
|
||||
self.to_screen('Duration of the file has been fixed')
|
||||
else:
|
||||
self.report_warning(f'Failed to fix duration of the file. (baseMediaDecodeTime offset = {bmdt_offset}, sample duration cutoff = {sdur_cutoff})')
|
||||
|
||||
self._downloader.replace(temp_filename, filename)
|
||||
|
||||
return [], information
|
Loading…
Reference in New Issue
Block a user