"""Decode a video stream with libav."""
import math
import pathlib
import queue
import threading
from collections.abc import Iterator
from fractions import Fraction

import av

from cutcutcodec.core.collections.sa import SliceAccessor
from cutcutcodec.core.exceptions import DecodeError, MissingInformation, OutOfTimeRange
from cutcutcodec.core.signal.predict import LinearPredictor
PRED_BUFFER: int = 16  # number of frames predicted in advance
# seek forward only if the jump is bigger than this time delta (in seconds),
# it is also the step used to aim earlier when a previous seek landed too late
JUMP: Fraction = Fraction(30)
class FrameAnticipator(threading.Thread):
    """Accurate and predictive reader of raw frames.

    The decoding runs in a daemon thread that is started by the constructor.
    The caller asks for frames with :py:meth:`snapshot`. Meanwhile, the decoding thread
    anticipates the next requests with a linear predictor and caches the decoded frames.
    """

    def __init__(self, filename: pathlib.Path, idx: int, av_kwargs: dict[str, object]):
        """Initialise the Anticipator and start the decoding thread.

        Parameters
        ----------
        filename : pathlib.Path
            The path of the video file.
        idx : int
            The absolute index of the video stream, including the position of all the other streams.
        av_kwargs : dict
            Transmitted to ``av.open``.
        """
        assert isinstance(filename, pathlib.Path), filename.__class__.__name__
        assert filename.is_file(), filename
        assert isinstance(idx, int), idx.__class__.__name__
        assert idx >= 0, idx
        assert isinstance(av_kwargs, dict), av_kwargs.__class__.__name__
        assert all(isinstance(k, str) for k in av_kwargs), av_kwargs
        self._filename = filename
        self._idx = idx
        self._av_kwargs = av_kwargs
        self._frames = SliceAccessor()  # to each interval, associate the av frame
        self._queues = {
            "request": queue.Queue(),  # just a timestamp
            "result": queue.Queue(),  # (timestamp, err_or_frame), timestamp is None if fatal
        }
        self._stop_flag: bool = False  # when True, kill the thread
        self._av_objects = {
            "frame_iter": None,  # iterator over the decoded frames of the opened container
            "curr_frame": None,  # cached current frame, None if not read yet
            "next_frame": None,  # cached following frame, None if not read yet or end reached
        }
        self._predictor = LinearPredictor(memory=6)  # 6 is arbitrary
        self._sup_duration = math.inf  # upper bound of the duration, lowered at end of stream
        # (start of the frame before the seek, seek target) -> start of the frame after the seek
        self._seek_table: dict[tuple[Fraction, Fraction], Fraction | float] = {}
        super().__init__(daemon=True)
        self.start()

    @staticmethod
    def _frame_timestamp(frame: av.video.VideoFrame) -> Fraction:
        """Return the display time of the video frame.

        The ``pts`` is used first, then the ``dts``, both scaled by the frame ``time_base``.
        Without ``time_base``, the ``frame.time`` attribute is used.

        Raises
        ------
        MissingInformation
            If no timing information is available.
        """
        if (time_base := frame.time_base) is None:
            if (timestamp := frame.time) is None:
                raise MissingInformation(f"unable to catch the time of the frame {frame}")
            return Fraction(timestamp)
        if (pts := frame.pts) is not None:
            return pts * time_base
        if (dts := frame.dts) is not None:
            return dts * time_base
        raise MissingInformation(f"unable to catch the time of the frame {frame}")

    def _next_timestamp(self) -> Iterator[Fraction]:
        """Yield the next timestamps to be decoded.

        Explicit requests from :py:meth:`snapshot` always have priority. Without a pending
        request, the predicted timestamps are yielded in increasing order. Only predictions
        that are not negative, below the known duration bound and not cached yet are yielded.
        The generator ends as soon as ``self._stop_flag`` is set.
        """
        while not self._stop_flag:
            predictions = sorted(set(map(Fraction, self._predictor.predict(PRED_BUFFER))))
            for timestamp in predictions:
                if self._stop_flag:
                    return
                try:
                    request = self._queues["request"].get_nowait()
                except queue.Empty:
                    if 0 <= timestamp < self._sup_duration and timestamp not in self._frames:
                        yield timestamp
                else:
                    yield request
                    break  # the predictor was updated by the request, restart the prediction
            else:  # every prediction was handled, wait a little for a request
                try:  # timeout as short as possible, to make cpu of self._predictor.predict small
                    request = self._queues["request"].get(timeout=0.100)
                except queue.Empty:
                    continue
                yield request

    def _seek(self, av_container: av.container.InputContainer, timestamp: Fraction) -> bool:
        """Try to move before the required position.

        The container is only seeked when going backward, or when jumping more than
        ``JUMP`` seconds forward. Otherwise decoding frame by frame is preferred.

        Returns
        -------
        seek_ok : boolean
            True if we are not after, False if something wrong (the container has to be reset).

        Notes
        -----
        This method has to be called from the thread (not from the main thread).
        """
        b_min, b_max = self.get_current_range()
        if b_min <= timestamp <= b_max + JUMP:
            return True  # close enough, no need to seek

        # adapt target timestamp according to historical observations:
        # aim earlier while a previous seek toward target landed after the timestamp
        target = timestamp
        while self._seek_table.get((b_min, target), 0) > timestamp:
            target = max(Fraction(0), target - JUMP)
            if target == 0:
                return timestamp >= b_min  # reset if backward

        # try to seek
        stream = av_container.streams[self._idx]
        if (time_base := stream.time_base) is None:
            return timestamp >= b_min  # unable to seek, reset if backward
        try:
            av_container.seek(  # very approximative
                int(target / time_base),
                backward=True,
                any_frame=False,
                stream=stream,
            )
        except av.error.PermissionError:  # happens sometimes
            self._seek_table[(b_min, target)] = Fraction(0)
            return False  # reset

        # forget the frames decoded before the seek, the next ones come from the new position
        self._av_objects["curr_frame"] = self._av_objects["next_frame"] = None
        try:
            landing = self._frame_timestamp(self.curr_frame)
        except OutOfTimeRange:  # nothing left to decode after the seek
            # makes a later seek from the same position toward this target aim earlier
            self._seek_table[(b_min, target)] = math.inf
            return False  # reset

        # verification we are before
        self._seek_table[(b_min, target)] = landing
        return timestamp >= landing

    @property
    def curr_frame(self) -> av.video.VideoFrame:
        """Return the frame at the current position, decoding it if not cached yet.

        Raises
        ------
        RuntimeError
            If no container has been opened yet.
        OutOfTimeRange
            If there is no frame left to read.
        """
        if self._av_objects["curr_frame"] is None:
            if (frame_iter := self._av_objects["frame_iter"]) is None:
                raise RuntimeError("call .start() before .curr_frame")
            try:
                self._av_objects["curr_frame"] = next(frame_iter)
            except (StopIteration, av.error.EOFError) as err:
                msg = "there is no frame left to read"
                raise OutOfTimeRange(msg) from err
        return self._av_objects["curr_frame"]

    def get_current_range(self) -> tuple[Fraction, Fraction]:
        """Return the time interval covered by the current frame.

        The interval ends at the start of the next frame. For the last frame, it ends at
        ``start + duration * time_base`` instead.

        Raises
        ------
        OutOfTimeRange
            If there is no current frame left to read.
        MissingInformation
            If the bounds can not be deduced from the frames metadata.
        """
        curr_frame = self.curr_frame
        start_time = self._frame_timestamp(curr_frame)
        if (next_frame := self.next_frame) is None:
            if curr_frame.duration is None or curr_frame.time_base is None:
                raise MissingInformation(f"failed get the duration of the frame {curr_frame}")
            return start_time, start_time + curr_frame.duration * curr_frame.time_base
        return start_time, self._frame_timestamp(next_frame)

    @property
    def next_frame(self) -> av.video.VideoFrame | None:
        """Return the next frame if exists, None else.

        The current frame is decoded first, so that the frames are consumed in order.
        """
        if self._av_objects["next_frame"] is None:
            _ = self.curr_frame
            try:
                self._av_objects["next_frame"] = next(self._av_objects["frame_iter"])
            except (StopIteration, av.error.EOFError):
                self._av_objects["next_frame"] = None
        return self._av_objects["next_frame"]

    def run(self):
        """Decode the frames in a separate thread.

        An unexpected exception is sent to :py:meth:`snapshot` (with a ``None`` timestamp)
        before being re-raised. Callers waiting for a frame are unblocked instead of hanging.
        """
        try:
            self._decode_loop()
        except Exception as err:  # thread boundary: report the error to the callers
            self._stop_flag = True
            self._queues["result"].put((None, err))
            raise

    def _decode_loop(self):
        """Open (or reset) the container and serve the timestamps until the thread is stopped."""
        while not self._stop_flag:
            # open (or reset) the stream
            with av.open(str(self._filename), "r", **self._av_kwargs) as av_container:
                av_stream = av_container.streams[self._idx]
                if av_stream.type != "video":
                    self._stop_flag = True
                    self._queues["result"].put(
                        (
                            None,
                            DecodeError(
                                f"the stream {self._idx} of {self._filename} has to be a video stream, "
                                f"not {av_stream.type}",
                            ),
                        ),
                    )
                    break
                self._av_objects["frame_iter"] = iter(av_container.decode(av_stream))
                self._av_objects["curr_frame"] = self._av_objects["next_frame"] = None

                for timestamp in self._next_timestamp():
                    # seek if needed
                    if not self._seek(av_container, timestamp):  # guaranteed to be not after
                        self._queues["request"].put(timestamp)  # served again after the reset
                        break

                    # decode until reaching the correct frame
                    # in practice, iterations are rare thanks to the prediction of future timestamps
                    try:
                        while timestamp >= self.get_current_range()[1]:
                            self._av_objects["curr_frame"], self._av_objects["next_frame"] = (
                                self._av_objects["next_frame"], None,  # iter in stream
                            )
                    except OutOfTimeRange as err:  # comes from self.curr_frame
                        self._sup_duration = min(self._sup_duration, timestamp)
                        self._queues["result"].put((timestamp, err))
                        break  # the frame iterator is exhausted, reopen the container

                    start, end = self.get_current_range()
                    frame = self.curr_frame
                    self._frames[start, end - start] = frame
                    self._queues["result"].put((timestamp, frame))

    def snapshot(self, timestamp: Fraction) -> av.video.VideoFrame:
        """Return a video frame at the required timestamp.

        Parameters
        ----------
        timestamp : Fraction
            The time of the wanted frame, in seconds.

        Returns
        -------
        frame : av.video.VideoFrame
            The cached frame covering ``timestamp``, or the one decoded by the thread.

        Raises
        ------
        OutOfTimeRange
            If the timestamp is after the end of the stream.
        Exception
            The error that stopped the decoding thread (e.g. ``DecodeError`` for a non-video stream).

        Examples
        --------
        >>> from fractions import Fraction
        >>> from cutcutcodec.core.io.read_av import FrameAnticipator
        >>> from cutcutcodec.utils import get_project_root
        >>> video = get_project_root() / "media" / "video" / "intro.webm"
        >>> frame_anticipator = FrameAnticipator(video, 0, {})
        >>> frame_anticipator.snapshot(Fraction(0)) # doctest: +ELLIPSIS
        <av.VideoFrame, pts=0 yuv420p 1280x720 at ...>
        >>> frame_anticipator.snapshot(Fraction(1)) # doctest: +ELLIPSIS
        <av.VideoFrame, pts=968 yuv420p 1280x720 at ...>
        >>> frame_anticipator.snapshot(Fraction(2)) # doctest: +ELLIPSIS
        <av.VideoFrame, pts=1969 yuv420p 1280x720 at ...>
        >>> frame_anticipator.snapshot(Fraction(3)) # doctest: +ELLIPSIS
        <av.VideoFrame, pts=2970 yuv420p 1280x720 at ...>
        >>> frame_anticipator.snapshot(Fraction(4)) # doctest: +ELLIPSIS
        <av.VideoFrame, pts=3971 yuv420p 1280x720 at ...>
        >>>
        """
        self._predictor.update(timestamp)
        try:
            return self._frames[timestamp]
        except KeyError:
            pass
        self._queues["request"].put(timestamp)
        while True:
            timestamp_received, err_or_frame = self._queues["result"].get()
            if timestamp_received is None:  # fatal error, the decoding thread is stopped
                self._queues["result"].put((None, err_or_frame))  # so the next calls raise too
                raise err_or_frame
            if timestamp_received == timestamp:  # results of predicted timestamps are skipped
                if isinstance(err_or_frame, Exception):
                    raise err_or_frame
                return err_or_frame