torchgear.module

class torchgear.module.DeviceMixIn

Bases: TensorAsDelegate

Objects of a class derived from this mixin may have torch.Tensor objects associated with them (e.g. the buffers and parameters of a torch.nn.Module). These tensors are assumed to be on the same device, which is the value of device.

property device: device

Device of this object.

Type:

torch.device
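A minimal sketch of how a device property like this can be derived, assuming the object is a torch.nn.Module with at least one parameter or buffer. The class names DeviceSketch and Scale are hypothetical and this is not torchgear's implementation; it only illustrates why a single shared device makes the first associated tensor representative.

```python
import torch
from torch import nn


class DeviceSketch:
    """Hypothetical stand-in for DeviceMixIn; assumes it is mixed into nn.Module."""

    @property
    def device(self) -> torch.device:
        # All associated tensors are assumed to share one device,
        # so the first parameter or buffer is representative.
        for tensor in (*self.parameters(), *self.buffers()):
            return tensor.device
        raise RuntimeError("no parameters or buffers to read the device from")


class Scale(nn.Module, DeviceSketch):
    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(3))


m = Scale()
print(m.device)  # cpu
```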

class torchgear.module.DtypeMixIn

Bases: TensorAsDelegate

Objects of a class derived from this mixin may have torch.Tensor objects associated with them (e.g. the buffers and parameters of a torch.nn.Module). These tensors are assumed to have the same data type, which is the value of dtype.

property dtype: dtype

Data type of this object.

Type:

torch.dtype
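Because torch.nn.Module.to() and torch.nn.Module.double() convert parameters and buffers together, a dtype read from any one of them stays representative after conversion. The sketch below is a hypothetical stand-in for DtypeMixIn (the names DtypeSketch and Affine are illustrative, not torchgear's), assuming the object is a torch.nn.Module.

```python
import torch
from torch import nn


class DtypeSketch:
    """Hypothetical stand-in for DtypeMixIn; assumes it is mixed into nn.Module."""

    @property
    def dtype(self) -> torch.dtype:
        # Parameters and buffers are assumed to share one dtype.
        for tensor in (*self.parameters(), *self.buffers()):
            return tensor.dtype
        raise RuntimeError("no parameters or buffers to read the dtype from")


class Affine(nn.Module, DtypeSketch):
    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(3))
        self.register_buffer("offset", torch.zeros(3))


m = Affine()
print(m.dtype)  # torch.float32 (the default dtype)
m.double()      # casts the parameter and the buffer together
print(m.dtype)  # torch.float64
```

One caveat: Module.to(dtype) only casts floating-point and complex tensors, so an integer buffer keeps its dtype and would break the shared-dtype assumption.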

class torchgear.module.FreezeParamMixIn

Bases: object

freeze(name: str | Sequence[str] = None)

Equivalent to self.set_optimizable(name, False). See set_optimizable().

set_optimizable(name: str | Sequence[str] = None, optimizable: bool = True)

Specify whether one or more parameters are optimizable.

Parameters:
  • name (str or Sequence[str]) – Name of the parameter, or a sequence of names. If None, all parameters are set. Names follow the same convention as torch.nn.Module.get_parameter().

  • optimizable (bool) – Whether the specified parameter is optimizable. Default: True.

unfreeze(name: str | Sequence[str] = None)

Equivalent to self.set_optimizable(name, True). See set_optimizable().
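A sketch of what set_optimizable() might do, assuming that "optimizable" maps to a parameter's requires_grad flag; the source does not state this, and the free function set_optimizable(module, ...) below is hypothetical rather than torchgear's method. Names are resolved with torch.nn.Module.get_parameter(), matching the naming convention cited above.

```python
from collections.abc import Sequence
from typing import Optional, Union

from torch import nn


def set_optimizable(
    module: nn.Module,
    name: Optional[Union[str, Sequence[str]]] = None,
    optimizable: bool = True,
) -> None:
    """Hypothetical sketch: toggle requires_grad on the named parameters."""
    if name is None:
        params = list(module.parameters())  # None selects every parameter
    elif isinstance(name, str):
        params = [module.get_parameter(name)]
    else:
        params = [module.get_parameter(n) for n in name]
    for p in params:
        p.requires_grad_(optimizable)


net = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 1))
set_optimizable(net, "0.weight", False)  # freeze one parameter
print(net[0].weight.requires_grad, net[0].bias.requires_grad)  # False True
set_optimizable(net, None, False)        # freeze everything
print(any(p.requires_grad for p in net.parameters()))  # False
```

Under this assumption a frozen parameter receives no gradient, and built-in optimizers such as torch.optim.SGD and torch.optim.Adam skip parameters whose .grad is None.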

class torchgear.module.TensorAsDelegate

Bases: object

A mixin that creates torch.Tensor objects on the same device and with the same dtype as a delegate tensor associated with self.

When mixed into a torch.nn.Module subclass, the delegate is taken from the first registered parameter or buffer. If none exists, a stub buffer is registered lazily on first use. Subclasses that are not derived from torch.nn.Module must implement _delegate().
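The delegate lookup described above can be sketched for torch.nn.Module subclasses as follows. This is an illustration, not torchgear's code: the class DelegateSketch, the buffer name _delegate_stub, the preference for parameters over buffers, and persistent=False (which keeps the stub out of state_dict()) are all assumptions.

```python
import torch
from torch import nn


class DelegateSketch:
    """Hypothetical sketch of delegate lookup; not torchgear's implementation."""

    _stub_name = "_delegate_stub"  # hypothetical buffer name

    def _delegate(self) -> torch.Tensor:
        # Assumption: parameters are preferred over buffers.
        for tensor in self.parameters():
            return tensor
        for tensor in self.buffers():
            return tensor
        # Nothing registered yet: add an empty stub buffer so that later
        # .to()/.double() calls carry it along with the module.
        self.register_buffer(self._stub_name, torch.empty(0), persistent=False)
        return getattr(self, self._stub_name)


class Empty(nn.Module, DelegateSketch):
    pass


class WithLinear(nn.Module, DelegateSketch):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(3, 2)


e = Empty()
print(e._delegate().shape)  # torch.Size([0]): the stub was created lazily
e.double()
print(e._delegate().dtype)  # torch.float64: the stub follows the module
w = WithLinear()
print(w._delegate() is w.lin.weight)  # True
```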

arange(*args, **kwargs) → Tensor

Call torch.arange with the device and dtype of the delegate tensor.

linspace(*args, **kwargs) → Tensor

Call torch.linspace with the device and dtype of the delegate tensor.

new_empty(size, **kwargs) → Tensor

Call torch.Tensor.new_empty on the delegate tensor.

new_full(size, fill_value, **kwargs) → Tensor

Call torch.Tensor.new_full on the delegate tensor.

new_ones(size, **kwargs) → Tensor

Call torch.Tensor.new_ones on the delegate tensor.

new_tensor(data, **kwargs) → Tensor

Call torch.Tensor.new_tensor on the delegate tensor.

new_zeros(size, **kwargs) → Tensor

Call torch.Tensor.new_zeros on the delegate tensor.

rand(*args, **kwargs) → Tensor

Call torch.rand with the device and dtype of the delegate tensor.

randn(*args, **kwargs) → Tensor

Call torch.randn with the device and dtype of the delegate tensor.
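The two families of methods above differ in how they obtain device and dtype: the torch.arange, torch.linspace, torch.rand and torch.randn wrappers must pass them explicitly, whereas the torch.Tensor.new_* methods inherit them from the delegate. A minimal sketch of one method from each family, assuming a torch.nn.Module with at least one parameter; FactorySketch and Model are hypothetical, and letting explicit device/dtype arguments take precedence is an assumption about the real methods.

```python
import torch
from torch import nn


class FactorySketch:
    """Hypothetical sketch of delegate-based factories; not torchgear's code."""

    def _delegate(self) -> torch.Tensor:
        return next(self.parameters())  # assumes at least one parameter exists

    def arange(self, *args, **kwargs) -> torch.Tensor:
        d = self._delegate()
        # Assumption: explicit device/dtype arguments take precedence.
        kwargs.setdefault("device", d.device)
        kwargs.setdefault("dtype", d.dtype)
        return torch.arange(*args, **kwargs)

    def new_zeros(self, size, **kwargs) -> torch.Tensor:
        # Tensor.new_zeros already inherits device and dtype from its receiver.
        return self._delegate().new_zeros(size, **kwargs)


class Model(nn.Module, FactorySketch):
    def __init__(self) -> None:
        super().__init__()
        self.w = nn.Parameter(torch.zeros(2))


m = Model().double()
print(m.arange(3).dtype)               # torch.float64
print(m.new_zeros((2, 2)).dtype)       # torch.float64
print(m.arange(3, dtype=torch.int64))  # tensor([0, 1, 2])
```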

class torchgear.module.TensorContainerMixIn

Bases: DeviceMixIn, DtypeMixIn

A mixin combining DeviceMixIn and DtypeMixIn.

Tensors associated with objects of a class derived from this mixin (e.g. parameters and buffers of a torch.nn.Module) are assumed to share the same device and dtype. _check_consistency() and _cast() validate and align incoming tensors against both attributes.
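The source names _check_consistency() and _cast() but not their signatures. The functions check_consistency() and cast() below are hypothetical analogues that take the container's device and dtype explicitly, sketching what "validate" and "align" plausibly mean for a single incoming tensor.

```python
import torch


def check_consistency(t: torch.Tensor, device: torch.device, dtype: torch.dtype) -> None:
    """Hypothetical analogue of _check_consistency(): raise if t does not match."""
    if t.device != device:
        raise ValueError(f"expected device {device}, got {t.device}")
    if t.dtype != dtype:
        raise ValueError(f"expected dtype {dtype}, got {t.dtype}")


def cast(t: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical analogue of _cast(): move and convert t to match."""
    # Tensor.to returns t itself when device and dtype already match.
    return t.to(device=device, dtype=dtype)


dev, dt = torch.device("cpu"), torch.float64
x = torch.ones(3)              # float32 on the CPU
y = cast(x, dev, dt)
check_consistency(y, dev, dt)  # passes silently
try:
    check_consistency(x, dev, dt)
except ValueError as err:
    print(err)  # expected dtype torch.float64, got torch.float32
```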

class torchgear.module.TorchgearModule(*args: Any, **kwargs: Any)

Bases: Module, TensorContainerMixIn, FreezeParamMixIn

Base class for torchgear models.

class torchgear.module.WrapperModule(func: Callable, *args, **kwargs)

Bases: Module

A class to wrap a function as a torch.nn.Module.

>>> import torch
>>> from torchgear.module import WrapperModule
>>> s = WrapperModule(torch.sum, dim=(-2, -1))
>>> x = torch.rand(4, 3, 2)
>>> s(x).shape  # equivalent to torch.sum(x, dim=(-2, -1)).shape
torch.Size([4])

Parameters:
  • func (Callable) – The function to be wrapped.

  • args – Positional arguments to be passed to func when this module is called.

  • kwargs – Keyword arguments to be passed to func when this module is called.

forward(*args, **kwargs)

Call the wrapped function func.

Parameters:
  • args – Additional positional arguments to be passed to func.

  • kwargs – Additional keyword arguments to be passed to func.

Returns:

The value returned by the wrapped function.

Return type:

Any
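A minimal re-implementation sketch of the behavior documented for WrapperModule. The class name WrapperSketch is hypothetical, and the argument ordering (call-time positional arguments before the stored ones, call-time keywords overriding stored keywords of the same name) is an assumption the source does not specify.

```python
from typing import Any, Callable

import torch
from torch import nn


class WrapperSketch(nn.Module):
    """Hypothetical re-implementation of WrapperModule's documented behavior."""

    def __init__(self, func: Callable, *args: Any, **kwargs: Any) -> None:
        super().__init__()
        self.func = func      # the wrapped callable
        self.args = args      # stored positional arguments
        self.kwargs = kwargs  # stored keyword arguments

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # Assumption: call-time positionals come first, then stored ones;
        # call-time keywords override stored keywords of the same name.
        return self.func(*args, *self.args, **{**self.kwargs, **kwargs})


pipeline = nn.Sequential(
    nn.Linear(5, 5),
    WrapperSketch(torch.relu),
    WrapperSketch(torch.sum, dim=-1),
)
print(pipeline(torch.rand(4, 5)).shape)                  # torch.Size([4])
print(WrapperSketch(torch.pow, 2)(torch.tensor([3.0])))  # tensor([9.])
```

Wrapping plain functions this way lets them sit inside containers such as torch.nn.Sequential alongside ordinary layers.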