Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ldm/models/diffusion/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ def __init__(self, model, schedule="linear", **kwargs):

def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
if attr.device != torch.device("cuda") and torch.cuda.is_available():
attr = attr.to(torch.device("cuda"))
else:
attr = attr.to(torch.device("cpu"))
setattr(self, name, attr)

def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
Expand Down
4 changes: 3 additions & 1 deletion ldm/models/diffusion/plms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ def __init__(self, model, schedule="linear", **kwargs):

def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
if attr.device != torch.device("cuda") and torch.cuda.is_available():
attr = attr.to(torch.device("cuda"))
else:
attr = attr.to(torch.device("cpu"))
setattr(self, name, attr)

def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
Expand Down
14 changes: 7 additions & 7 deletions ldm/modules/encoders/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def forward(self, batch, key=None):

class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda" if torch.cuda.is_available() else "cpu"):
super().__init__()
self.device = device
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer))

Expand All @@ -52,11 +52,11 @@ def encode(self, x):

class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device="cuda", vq_interface=True, max_length=77):
def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", vq_interface=True, max_length=77):
super().__init__()
from transformers import BertTokenizerFast # TODO: add to requirements
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
self.device = device
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length

Expand All @@ -80,12 +80,12 @@ def decode(self, text):
class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizer model and adds some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
device="cuda" if torch.cuda.is_available() else "cpu", use_tokenizer=True, embedding_dropout=0.0):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
self.device = device
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout)
Expand Down Expand Up @@ -139,7 +139,7 @@ class FrozenCLIPTextEmbedder(nn.Module):
"""
Uses the CLIP transformer encoder for text.
"""
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
def __init__(self, version='ViT-L/14', device="cuda" if torch.cuda.is_available() else "cpu", max_length=77, n_repeat=1, normalize=True):
super().__init__()
self.model, _ = clip.load(version, jit=False, device="cpu")
self.device = device
Expand Down
4 changes: 2 additions & 2 deletions notebook_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def load_model_from_config(config, ckpt):
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
model.cuda()
model.cuda() if torch.cuda.is_available() else model.cpu()
model.eval()
return {"model": model}, global_step

Expand Down Expand Up @@ -117,7 +117,7 @@ def get_cond(mode, selected_path):
c = rearrange(c, '1 c h w -> 1 h w c')
c = 2. * c - 1.

c = c.to(torch.device("cuda"))
c = c.to(torch.device("cuda")) if torch.cuda.is_available() else c.to(torch.device("cpu"))
example["LR_image"] = c
example["image"] = c_up

Expand Down
7 changes: 5 additions & 2 deletions scripts/knn2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:")
print(u)

model.cuda()
model.cuda() if torch.cuda.is_available() else model.cpu()
model.eval()
return model

Expand Down Expand Up @@ -124,6 +124,8 @@ def load_retriever(self, version='ViT-L/14', ):
model = FrozenClipImageEmbedder(model=version)
if torch.cuda.is_available():
model.cuda()
else:
model.cpu()
model.eval()
return model

Expand Down Expand Up @@ -358,7 +360,8 @@ def __call__(self, x, n):
uc = None
if searcher is not None:
nn_dict = searcher(c, opt.knn)
c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1)
nn_embeddings = torch.from_numpy(nn_dict['nn_embeddings']).cuda() if torch.cuda.is_available() else torch.from_numpy(nn_dict['nn_embeddings']).cpu()
c = torch.cat([c, nn_embeddings], dim=1)
if opt.scale != 1.0:
uc = torch.zeros_like(c)
if isinstance(prompts, tuple):
Expand Down
2 changes: 1 addition & 1 deletion scripts/sample_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def get_parser():
def load_model_from_config(config, sd):
model = instantiate_from_config(config)
model.load_state_dict(sd,strict=False)
model.cuda()
model.cuda() if torch.cuda.is_available() else model.cpu()
model.eval()
return model

Expand Down
2 changes: 1 addition & 1 deletion scripts/txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:")
print(u)

model.cuda()
model.cuda() if torch.cuda.is_available() else model.cpu()
model.eval()
return model

Expand Down