torchgear.module
- class torchgear.module.DeviceMixIn
Bases: TensorAsDelegate
Some torch.Tensors may be associated with objects of a class derived from this mixin (e.g. the buffers and parameters of a torch.nn.Module). They are assumed to be on the same device, which is the value of device.
- class torchgear.module.DtypeMixIn
Bases: TensorAsDelegate
Some torch.Tensors may be associated with objects of a class derived from this mixin (e.g. the buffers and parameters of a torch.nn.Module). They are assumed to have the same data type, which is the value of dtype.
- class torchgear.module.FreezeParamMixIn
Bases: object
- freeze(name: str | Sequence[str] = None)
Equivalent to self.set_optimizable(name, False). See set_optimizable().
- set_optimizable(name: str | Sequence[str] = None, optimizable: bool = True)
Specify whether a parameter is optimizable.
- Parameters:
name (str | Sequence[str]) – Name(s) of the parameter(s). If None, all parameters will be set. Names follow the same convention as torch.nn.Module.get_parameter().
optimizable (bool) – Whether the specified parameter is optimizable. Default: True.
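The behaviour described above can be sketched as follows. This is a minimal illustration, not torchgear's actual implementation: the class names FreezeSketch and TinyNet are hypothetical, and it assumes that "optimizable" corresponds to a parameter's requires_grad flag, which the documentation does not state.

```python
from typing import Optional, Sequence, Union

import torch
from torch import nn


class FreezeSketch:
    """Hypothetical stand-in for FreezeParamMixIn; not torchgear's code."""

    def set_optimizable(
        self,
        name: Optional[Union[str, Sequence[str]]] = None,
        optimizable: bool = True,
    ) -> None:
        if name is None:
            # No name given: apply the flag to every parameter.
            params = list(self.parameters())
        else:
            names = [name] if isinstance(name, str) else list(name)
            # Dotted names such as "encoder.weight" are resolved by
            # torch.nn.Module.get_parameter().
            params = [self.get_parameter(n) for n in names]
        for p in params:
            # Assumption: "optimizable" maps onto requires_grad.
            p.requires_grad_(optimizable)

    def freeze(self, name: Optional[Union[str, Sequence[str]]] = None) -> None:
        self.set_optimizable(name, False)


class TinyNet(nn.Module, FreezeSketch):
    """Hypothetical two-layer model used only for this example."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)


net = TinyNet()
net.freeze("encoder.weight")
# Names of the parameters that still require gradients.
trainable = [n for n, p in net.named_parameters() if p.requires_grad]
```

Under these assumptions, an optimizer built from the parameters that still require gradients would leave encoder.weight untouched.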
- class torchgear.module.TensorAsDelegate
Bases: object
A mixin that creates torch.Tensors on the same device and with the same dtype as a delegate tensor associated with self.
When mixed into a torch.nn.Module subclass, the delegate is taken from the first registered parameter or buffer. If none exists, a stub buffer is registered lazily on first use. Subclasses that are not derived from torch.nn.Module must implement _delegate().
- arange(*args, **kwargs) → Tensor
Call torch.arange with the device and dtype of the delegate tensor.
- linspace(*args, **kwargs) → Tensor
Call torch.linspace with the device and dtype of the delegate tensor.
- new_full(size, fill_value, **kwargs) → Tensor
Call torch.Tensor.new_full on the delegate tensor.
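The delegate pattern described for TensorAsDelegate might look roughly like the sketch below. It is not torchgear's implementation. DelegateSketch and Grid are hypothetical names, the stub buffer name "_stub" and its non-persistent registration are assumptions, and letting explicit device or dtype keyword arguments override the delegate is also an assumption.

```python
import torch
from torch import nn


class DelegateSketch:
    """Hypothetical stand-in for TensorAsDelegate; not torchgear's code."""

    def _delegate(self) -> torch.Tensor:
        # Assumption: prefer the first parameter, then the first buffer.
        for tensor in self.parameters():
            return tensor
        for tensor in self.buffers():
            return tensor
        # Assumption: lazily register an empty stub buffer when the
        # module owns no tensors yet.
        self.register_buffer("_stub", torch.empty(0), persistent=False)
        return self._stub

    def _defaults(self, kwargs: dict) -> dict:
        d = self._delegate()
        # Explicit keyword arguments win over the delegate's attributes.
        kwargs.setdefault("device", d.device)
        kwargs.setdefault("dtype", d.dtype)
        return kwargs

    def arange(self, *args, **kwargs) -> torch.Tensor:
        return torch.arange(*args, **self._defaults(kwargs))

    def linspace(self, *args, **kwargs) -> torch.Tensor:
        return torch.linspace(*args, **self._defaults(kwargs))

    def new_full(self, size, fill_value, **kwargs) -> torch.Tensor:
        return self._delegate().new_full(size, fill_value, **kwargs)


class Grid(nn.Module, DelegateSketch):
    """Hypothetical module whose only tensor is a float64 buffer."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("origin", torch.zeros(2, dtype=torch.float64))


grid = Grid()
steps = grid.arange(5)               # float64, on the buffer's device
ticks = grid.linspace(0.0, 1.0, 3)   # float64, on the buffer's device
filled = grid.new_full((2, 3), 1.5)  # float64, shape (2, 3)
```

The point of the pattern is that helper tensors created inside the module follow it when it is moved with .to(...) or cast to another dtype, without passing device and dtype around explicitly.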
- class torchgear.module.TensorContainerMixIn
Bases: DeviceMixIn, DtypeMixIn
A mixin combining DeviceMixIn and DtypeMixIn.
Tensors associated with objects of a class derived from this mixin (e.g. parameters and buffers of a torch.nn.Module) are assumed to share the same device and dtype. _check_consistency() and _cast() validate and align incoming tensors against both attributes.
- class torchgear.module.TorchgearModule(*args: Any, **kwargs: Any)
Bases: Module, TensorContainerMixIn, FreezeParamMixIn
Base class for torchgear models.
- class torchgear.module.WrapperModule(func: Callable, *args, **kwargs)
Bases: Module
A class to wrap a function as a torch.nn.Module.
>>> s = WrapperModule(torch.sum, dim=(-2, -1))
>>> x = torch.rand(2, 3, 4)
>>> s(x)  # equivalent to torch.sum(x, dim=(-2, -1))
- Parameters:
func (Callable) – The function to be wrapped.
args – Positional arguments to be passed to func when this module is called.
kwargs – Keyword arguments to be passed to func when this module is called.
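A wrapper of this kind could be written along the following lines. This is a sketch only and not the library's source: WrapperSketch is a hypothetical name, and the ordering of call-time inputs before the stored positional arguments is an assumption inferred from the example above.

```python
from typing import Any, Callable

import torch
from torch import nn


class WrapperSketch(nn.Module):
    """Hypothetical stand-in for WrapperModule; not torchgear's code."""

    def __init__(self, func: Callable, *args: Any, **kwargs: Any) -> None:
        super().__init__()
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def forward(self, *inputs: Any) -> Any:
        # Assumption: call-time inputs come first, stored arguments follow.
        return self.func(*inputs, *self.args, **self.kwargs)


# Wrapping a plain function lets it sit inside nn.Sequential.
pool = nn.Sequential(
    nn.Flatten(start_dim=2),             # (N, C, H, W) -> (N, C, H*W)
    WrapperSketch(torch.mean, dim=-1),   # (N, C, H*W) -> (N, C)
)
x = torch.ones(2, 3, 4, 5)
y = pool(x)
```

Storing the arguments at construction time is what makes the function usable wherever a torch.nn.Module is expected, such as in nn.Sequential.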