Source code for cutcutcodec.core.nn.start

"""Help to store and load the weights."""

import hashlib
import io
import logging
import lzma
import pathlib
import tempfile
import urllib

import torch
import tqdm

from cutcutcodec.utils import get_project_root


[docs] def download(stem: str) -> pathlib.Path: """Attempt to recover network weight on internet. Parameters ---------- stem : str The hexadecimal hash of the model weights. Returns ------- weights : pathlib.Path The path of the downloded weights. Raises ------ FileNotFoundError If the weights doese not exists on the gitlab. ConnectionError If the connection to internet is missing or broken. Examples -------- >>> from cutcutcodec.core.nn.start import download >>> download("631ac8be291fd6c627e6b3b54ce37fdd") PosixPath('/tmp/631ac8be291fd6c627e6b3b54ce37fdd.pt.xz') >>> """ url = ( "https://framagit.org/robinechuca/cutcutcodec/" f"-/raw/main/cutcutcodec/models/{stem}.pt.xz" ) try: with urllib.request.urlopen(url) as req: if req.status != 200: raise ConnectionError( f"please check your internet connection, failed to get {stem}", ) filename = pathlib.Path(tempfile.gettempdir()) / f"{stem}.pt.xz" with open(filename, "wb") as raw: progress = tqdm.tqdm(desc="Download", total=req.length) while data := req.read(1024): raw.write(data) progress.update(len(data)) except urllib.error.HTTPError as err: raise FileNotFoundError(f"the weights {stem} are not existing online") from err return filename
[docs] def load(model: torch.nn.Module, weights: pathlib.Path | str | bytes | None = None): """Load the pretrained weights. Parameters ---------- model : torch.nn.Module The model to be loaded. weights : pathlike, optional The path to the loading weight file with the suffix .pt or .pt.xz """ assert isinstance(model, torch.nn.Module), model.__class__.__name__ # get weights if weights is None: root = pathlib.Path.home() / ".cache" / "cutcutcodec" / "models" root.mkdir(parents=True, exist_ok=True) stem = hashlib.md5(str(model).encode(), usedforsecurity=False).hexdigest() weights = root / f"{stem}.pt" if not weights.exists(): # need to be extracted comp = get_project_root() / "models" / f"{stem}.pt.xz" if not comp.exists(): try: comp = download(stem) except FileNotFoundError: logging.warning("%s weights not found locally or online", comp) return with lzma.open(comp, "rb") as src, open(weights, "wb") as dst: dst.write(src.read()) else: weights = pathlib.Path(weights).expanduser() if not weights.exists(): logging.warning("the weights %s were not founded", comp) return # load weights model.load_state_dict(torch.load(weights, weights_only=True))
[docs] def save( model: torch.nn.Module, weights: pathlib.Path | str | bytes | None = None, ) -> pathlib.Path: """Load the pretrained weights. Parameters ---------- model : torch.nn.Module The model to be loaded. weights : pathlib, optional The path of the recorded file, with the extention .pt.xz Returns ------- weights : pathlib.Path The recorded file. Examples -------- >>> import pathlib, tempfile >>> import torch >>> from cutcutcodec.core.nn.start import save >>> weights = pathlib.Path(tempfile.gettempdir()) / "tmp.pt.xz" >>> class Model(torch.nn.Module): ... def __init__(self): ... super().__init__() ... self.layer = torch.nn.Conv2d(3, 3, kernel_size=3) ... >>> model = Model() >>> save(model, weights) PosixPath('/tmp/tmp.pt.xz') >>> weights.unlink() >>> """ assert isinstance(model, torch.nn.Module), model.__class__.__name__ # save model model.to("cpu") # to avoid loading cuda data on cpu only environement buffer = io.BytesIO() torch.save(model.state_dict(), buffer, _use_new_zipfile_serialization=False) buffer.seek(0) # get record file if weights is None: stem = hashlib.md5(str(model).encode(), usedforsecurity=False).hexdigest() weights = get_project_root() / "models" / f"{stem}.pt.xz" else: weights = pathlib.Path(weights).expanduser() assert weights.suffixes == [".pt", ".xz"], weights # compress with lzma.open(weights, "wb", preset=lzma.PRESET_EXTREME) as file: file.write(buffer.read()) return weights