cutcutcodec.core.nn.start

Helpers to store and load the network weights.

Functions

download(stem)

Attempt to download the network weights from the internet.

load(model[, weights])

Load the pretrained weights.

save(model[, weights])

Save the model weights.

Details

cutcutcodec.core.nn.start.download(stem: str) → Path

Attempt to download the network weights from the internet.

Parameters

stem : str

The hexadecimal hash of the model weights.

Returns

weights : pathlib.Path

The path of the downloaded weights.

Raises

FileNotFoundError

If the weights do not exist on the GitLab repository.

ConnectionError

If the internet connection is missing or broken.

Examples

>>> from cutcutcodec.core.nn.start import download
>>> download("631ac8be291fd6c627e6b3b54ce37fdd")
PosixPath('/tmp/631ac8be291fd6c627e6b3b54ce37fdd.pt.xz')
>>>
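The two documented exceptions let a caller fall back gracefully when the weights cannot be fetched. The sketch below is illustrative only: `fetch_weights` is a hypothetical helper, not part of cutcutcodec. It takes the download function as a parameter (in practice `cutcutcodec.core.nn.start.download`) and assumes the stem is a lowercase hexadecimal string like the one in the example above.

```python
import pathlib
import re
from typing import Callable, Optional


# Hypothetical helper, not part of cutcutcodec.
def fetch_weights(
    stem: str, download: Callable[[str], pathlib.Path]
) -> Optional[pathlib.Path]:
    """Return the path of the downloaded weights, or None if unavailable.

    ``download`` is expected to behave like
    ``cutcutcodec.core.nn.start.download``.
    """
    # assumption: the stem is a lowercase hexadecimal hash
    if re.fullmatch(r"[0-9a-f]+", stem) is None:
        raise ValueError(f"stem {stem!r} is not a hexadecimal hash")
    try:
        return download(stem)
    except FileNotFoundError:
        # the weights are not published on the remote repository
        return None
    except ConnectionError:
        # no internet connection, or the connection broke
        return None
```

On success the helper returns whatever `download` returns; when the weights are missing or the machine is offline it returns `None`, so the caller can decide how to proceed (for instance, keep the randomly initialised weights).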
cutcutcodec.core.nn.start.load(model: Module, weights: Path | str | bytes | None = None)

Load the pretrained weights.

Parameters

model : torch.nn.Module

The model into which the weights are loaded.

weights : pathlike, optional

The path of the weight file to load, with the suffix .pt or .pt.xz.
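Because the weight file may carry either suffix, a loader has to decompress `.pt.xz` files before deserializing them. The following sketch assumes that a `.pt.xz` file is simply an xz (LZMA) compressed `.pt` stream, which this page does not confirm. `read_weights` is a hypothetical helper: it only accepts `str` or path-like arguments, so the `bytes` case of the real signature is not covered.

```python
import io
import lzma
import os
import pathlib
from typing import Union


# Hypothetical helper, not part of cutcutcodec.
def read_weights(weights: Union[str, os.PathLike]) -> io.BytesIO:
    """Return the (decompressed) content of a .pt or .pt.xz file as a buffer."""
    path = pathlib.Path(weights)
    if path.name.endswith(".pt.xz"):
        # assumption: .pt.xz is an xz-compressed .pt file
        data = lzma.decompress(path.read_bytes())
    elif path.suffix == ".pt":
        data = path.read_bytes()
    else:
        raise ValueError(f"unsupported weight file suffix: {path.name!r}")
    # the buffer could then be handed to torch.load, for instance
    # model.load_state_dict(torch.load(buffer)) if the file holds a state dict
    return io.BytesIO(data)
```

Returning an in-memory buffer keeps the suffix handling separate from the deserialization step, since `torch.load` accepts file-like objects as well as paths.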

cutcutcodec.core.nn.start.save(model: Module, weights: Path | str | bytes | None = None) → Path

Save the model weights.

Parameters

model : torch.nn.Module

The model whose weights are saved.

weights : pathlike, optional

The path of the file to write, with the extension .pt.xz.

Returns

weights : pathlib.Path

The path of the recorded file.

Examples

>>> import pathlib, tempfile
>>> import torch
>>> from cutcutcodec.core.nn.start import save
>>> weights = pathlib.Path(tempfile.gettempdir()) / "tmp.pt.xz"
>>> class Model(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.layer = torch.nn.Conv2d(3, 3, kernel_size=3)
...
>>> model = Model()
>>> save(model, weights)
PosixPath('/tmp/tmp.pt.xz')
>>> weights.unlink()
>>>
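The writing side can be sketched the same way. `write_weights` is a hypothetical helper, not the library's implementation: it assumes the payload has already been serialized (for example with `torch.save(model.state_dict(), buffer)` into an `io.BytesIO`) and that `.pt.xz` denotes an xz-compressed stream. It rejects any other suffix, mirroring the documented extension.

```python
import lzma
import os
import pathlib
from typing import Union


# Hypothetical helper, not part of cutcutcodec.
def write_weights(payload: bytes, weights: Union[str, os.PathLike]) -> pathlib.Path:
    """Write xz-compressed serialized weights and return the file path."""
    path = pathlib.Path(weights)
    if not path.name.endswith(".pt.xz"):
        raise ValueError(f"the weight file must end with .pt.xz, not {path.name!r}")
    # assumption: the payload comes from something like
    #   buffer = io.BytesIO(); torch.save(model.state_dict(), buffer)
    #   payload = buffer.getvalue()
    # preset=9 (strongest xz compression) is an arbitrary choice here
    path.write_bytes(lzma.compress(payload, preset=9))
    return path
```

Reading the file back with `lzma.decompress` recovers the original payload byte for byte.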