Code source de cutcutcodec.core.nn.start

#!/usr/bin/env python3

"""Help to store and load the weights."""

import hashlib
import io
import logging
import lzma
import pathlib
import typing

import torch

from cutcutcodec.utils import get_project_root


[docs] def load(model: torch.nn.Module, weights: typing.Optional[pathlib.Path | str | bytes] = None): """Load the pretrained weights. Parameters ---------- model : torch.nn.Module The model to be loaded. weights : pathlike, optional The path to the loading weight file with the suffix .pt or .pt.xz """ assert isinstance(model, torch.nn.Module), model.__class__.__name__ # get weights if weights is None: root = pathlib.Path.home() / ".cache" / "cutcutcodec" / "models" root.mkdir(parents=True, exist_ok=True) stem = hashlib.md5(str(model).encode(), usedforsecurity=False).hexdigest() weights = root / f"{stem}.pt" if not weights.exists(): # need to be extracted comp = get_project_root() / "models" / f"{stem}.pt.xz" if not comp.exists(): logging.warning("the weights %s were not founded", comp) return with lzma.open(comp, "rb") as src, open(weights, "wb") as dst: dst.write(src.read()) else: weights = pathlib.Path(weights).expanduser() if not weights.exists(): logging.warning("the weights %s were not founded", comp) return # load weights model.load_state_dict(torch.load(weights, weights_only=True))
[docs] def save( model: torch.nn.Module, weights: typing.Optional[pathlib.Path | str | bytes] = None ) -> pathlib.Path: """Load the pretrained weights. Parameters ---------- model : torch.nn.Module The model to be loaded. weights : pathlib, optional The path of the recorded file, with the extention .pt.xz Returns ------- weights : pathlib.Path The recorded file. Examples -------- >>> import pathlib, tempfile >>> import torch >>> from cutcutcodec.core.nn.start import save >>> weights = pathlib.Path(tempfile.gettempdir()) / "tmp.pt.xz" >>> class Model(torch.nn.Module): ... def __init__(self): ... super().__init__() ... self.layer = torch.nn.Conv2d(3, 3, kernel_size=3) ... >>> model = Model() >>> save(model, weights) PosixPath('/tmp/tmp.pt.xz') >>> weights.unlink() >>> """ assert isinstance(model, torch.nn.Module), model.__class__.__name__ # save model model.to("cpu") # to avoid loading cuda data on cpu only environement buffer = io.BytesIO() torch.save(model.state_dict(), buffer) buffer.seek(0) # get record file if weights is None: stem = hashlib.md5(str(model).encode(), usedforsecurity=False).hexdigest() weights = get_project_root() / "models" / f"{stem}.pt.xz" else: weights = pathlib.Path(weights).expanduser() assert weights.suffixes == [".pt", ".xz"], weights # compress with lzma.open(weights, "wb", preset=lzma.PRESET_EXTREME) as file: file.write(buffer.read()) return weights