|
1 | | -import argparse, os, sys, datetime, glob, importlib |
| 1 | +import argparse, os, sys, datetime, glob |
2 | 2 | from omegaconf import OmegaConf |
3 | 3 | import numpy as np |
4 | 4 | from PIL import Image |
|
10 | 10 | from pytorch_lightning.trainer import Trainer |
11 | 11 | from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor |
12 | 12 | from pytorch_lightning.utilities.distributed import rank_zero_only |
13 | | - |
14 | | -def get_obj_from_str(string, reload=False): |
15 | | - module, cls = string.rsplit(".", 1) |
16 | | - if reload: |
17 | | - module_imp = importlib.import_module(module) |
18 | | - importlib.reload(module_imp) |
19 | | - return getattr(importlib.import_module(module, package=None), cls) |
20 | | - |
| 13 | +from taming.util import instantiate_from_config |
21 | 14 |
|
22 | 15 | def get_parser(**parser_kwargs): |
23 | 16 | def str2bool(v): |
@@ -110,12 +103,6 @@ def nondefault_trainer_args(opt): |
110 | 103 | return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) |
111 | 104 |
|
112 | 105 |
|
113 | | -def instantiate_from_config(config): |
114 | | - if not "target" in config: |
115 | | - raise KeyError("Expected key `target` to instantiate.") |
116 | | - return get_obj_from_str(config["target"])(**config.get("params", dict())) |
117 | | - |
118 | | - |
119 | 106 | class WrappedDataset(Dataset): |
120 | 107 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" |
121 | 108 | def __init__(self, dataset): |
|
0 commit comments