WrapperModule
- class dnois.torch.WrapperModule(func: Callable, *args, **kwargs)
A class that wraps a function as a torch.nn.Module.

>>> s = WrapperModule(torch.sum, dim=(-2, -1))
>>> x = torch.rand(4, 5, 6)
>>> s(x)  # equivalent to torch.sum(x, dim=(-2, -1))
- Parameters:
  func (Callable) – The function to be wrapped.
  args – Positional arguments to be passed to func whenever this module is called.
  kwargs – Keyword arguments to be passed to func whenever this module is called.
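The following minimal sketch shows what args and kwargs do: they are stored when the module is built and reused on every call. It is not the dnois source. The class name SketchWrapperModule is hypothetical. The way stored and call-time arguments are combined is also an assumption: call-time positional arguments come first, then the stored ones, and on a duplicate keyword the call-time value wins.

```python
import torch
from torch import nn
from typing import Any, Callable


class SketchWrapperModule(nn.Module):
    """Hypothetical stand-in for dnois.torch.WrapperModule (not its actual source)."""

    def __init__(self, func: Callable, *args, **kwargs):
        super().__init__()
        self.func = func      # the wrapped callable
        self.args = args      # positional arguments stored at construction
        self.kwargs = kwargs  # keyword arguments stored at construction

    def forward(self, *args, **kwargs) -> Any:
        # Assumption: call-time positional args first, then stored ones;
        # call-time keywords override stored keywords with the same name.
        return self.func(*args, *self.args, **{**self.kwargs, **kwargs})


# Stored keyword arguments: reduce over the last two dimensions.
s = SketchWrapperModule(torch.sum, dim=(-2, -1))
x = torch.rand(4, 5, 6)
y = s(x)  # same as torch.sum(x, dim=(-2, -1))
print(y.shape)  # torch.Size([4])

# Stored positional arguments: under the assumed order this is torch.transpose(x, -2, -1).
t = SketchWrapperModule(torch.transpose, -2, -1)
print(t(x).shape)  # torch.Size([4, 6, 5])
```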
- forward(*args, **kwargs)
Call the wrapped function func with the arguments given at construction together with those given here.
- Parameters:
  args – Additional positional arguments to be passed to func.
  kwargs – Additional keyword arguments to be passed to func.
- Returns:
The value returned by the wrapped function.
- Return type:
Any
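The wrapper is itself a torch.nn.Module, so a plain function can go inside containers such as torch.nn.Sequential. Extra keyword arguments given at call time are also passed on to func. The sketch below shows both. It uses a hypothetical stand-in class, FuncModule, so that it runs without dnois installed. Its argument-merging order is an assumption and is not taken from the dnois source.

```python
import torch
from torch import nn


class FuncModule(nn.Module):
    """Hypothetical minimal stand-in for dnois.torch.WrapperModule."""

    def __init__(self, func, *args, **kwargs):
        super().__init__()
        self.func, self.args, self.kwargs = func, args, kwargs

    def forward(self, *args, **kwargs):
        # Assumption: call-time positional args first, then stored ones;
        # call-time keywords override stored keywords with the same name.
        return self.func(*args, *self.args, **{**self.kwargs, **kwargs})


# Extra keyword arguments supplied at call time are forwarded to func.
s = FuncModule(torch.sum, dim=(-2, -1))
x = torch.rand(4, 5, 6)
print(s(x, keepdim=True).shape)  # torch.Size([4, 1, 1])

# As a module, a wrapped function can sit inside nn.Sequential.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    FuncModule(torch.mean, dim=(-2, -1)),  # average over height and width
    nn.Linear(8, 2),
)
out = model(torch.rand(2, 3, 16, 16))
print(out.shape)  # torch.Size([2, 2])
```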