WrapperModule

class dnois.torch.WrapperModule(func: Callable, *args, **kwargs)

Wraps a function as a torch.nn.Module, so that it can be used wherever a module is expected (for example, as a layer in torch.nn.Sequential).

>>> s = WrapperModule(torch.sum, dim=(-2, -1))
>>> x = torch.rand(3, 4)  # needs at least two dimensions for dim=(-2, -1)
>>> s(x)  # equivalent to torch.sum(x, dim=(-2, -1))

Parameters:
  • func (Callable) – The function to be wrapped.

  • args – Positional arguments stored at construction and passed to func each time this module is called.

  • kwargs – Keyword arguments stored at construction and passed to func each time this module is called.

forward(*args, **kwargs)

Calls the wrapped function func with the arguments given here together with those supplied at construction.

Parameters:
  • args – Additional positional arguments to be passed to func for this call.

  • kwargs – Additional keyword arguments to be passed to func for this call.

Returns:

The value returned by the wrapped function.

Return type:

Any
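
As a rough illustration of the behaviour documented above, the following sketch shows one way such a wrapper could be written and then used as a layer in torch.nn.Sequential. It is not the dnois implementation: the class name FunctionModule, its attribute names, the order in which stored and call-time positional arguments are combined, and the rule that call-time keyword arguments override stored ones are all assumptions made for this example.

```python
from typing import Any, Callable

import torch
from torch import nn


class FunctionModule(nn.Module):
    """Hypothetical stand-in for WrapperModule (not the dnois source)."""

    def __init__(self, func: Callable, *args: Any, **kwargs: Any) -> None:
        super().__init__()
        # Remember the function and the arguments bound at construction.
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # Assumption: call-time positional arguments come first, followed by
        # the stored ones; call-time keyword arguments override stored ones.
        return self.func(*args, *self.args, **{**self.kwargs, **kwargs})


# A plain function can now sit next to ordinary layers.
model = nn.Sequential(
    nn.Conv2d(1, 3, kernel_size=3, padding=1),
    FunctionModule(torch.sum, dim=(-2, -1)),  # sum over height and width
)

x = torch.rand(2, 1, 8, 8)
y = model(x)
print(y.shape)  # torch.Size([2, 3])
```

Binding dim=(-2, -1) at construction means the reduction can be configured once, while the tensor to reduce is supplied on each call.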