Source code for cutcutcodec.core.opti.parallel.threading
"""Thread utils."""
import numbers
import os
import threading
import torch
[docs]
def get_num_threads(threads: numbers.Integral) -> int:
"""Return the number of threads."""
assert isinstance(threads, numbers.Integral), threads.__class__.__name__
if threads == 0:
return max(2, os.cpu_count()//2) if threading.current_thread().name == "MainThread" else 1
if threads < 0:
return max(2, os.cpu_count()//2)
return int(threads)
[docs]
class TorchThreads:
"""Context manager to set the number of torch threads.
Examples
--------
>>> import torch
>>> from cutcutcodec.core.opti.parallel.threading import TorchThreads
>>> (t := torch.get_num_threads()) != 1
True
>>> with TorchThreads(1):
... torch.get_num_threads()
...
1
>>> torch.get_num_threads() == t
True
>>>
"""
def __init__(self, threads: numbers.Integral):
"""Initialise the thread setter.
Parameters
----------
threads : int
The number of threads, same as ``get_num_threads``.
"""
self.threads = get_num_threads(threads)
self.torch_threads = None
# self.torch_interop_threads = None
def __enter__(self) -> int:
"""Set the threading torch context."""
self.torch_threads = torch.get_num_threads()
# self.torch_interop_threads = torch.get_num_interop_threads()
torch.set_num_threads(self.threads)
# torch.set_num_interop_threads(self.threads)
return self.threads
def __exit__(self, *_):
"""Reset the previous threads."""
torch.set_num_threads(self.torch_threads)
# torch.set_num_interop_threads(self.torch_interop_threads)