mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-18 20:14:34 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b22ccd9dd0 |
@@ -52,17 +52,12 @@ def groupby_prefix_and_trim(prefix, d):
|
|||||||
def cast_torch_tensor(fn):
|
def cast_torch_tensor(fn):
|
||||||
@wraps(fn)
|
@wraps(fn)
|
||||||
def inner(model, *args, **kwargs):
|
def inner(model, *args, **kwargs):
|
||||||
device = kwargs.pop('_device', next(model.parameters()).device)
|
device = next(model.parameters()).device
|
||||||
cast_device = kwargs.pop('_cast_device', True)
|
|
||||||
|
|
||||||
kwargs_keys = kwargs.keys()
|
kwargs_keys = kwargs.keys()
|
||||||
all_args = (*args, *kwargs.values())
|
all_args = (*args, *kwargs.values())
|
||||||
split_kwargs_index = len(all_args) - len(kwargs_keys)
|
split_kwargs_index = len(all_args) - len(kwargs_keys)
|
||||||
all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
|
all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
|
||||||
|
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
|
||||||
if cast_device:
|
|
||||||
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
|
|
||||||
|
|
||||||
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
|
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
|
||||||
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
|
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user