
Commit c33afad

custom model support in backend
1 parent 6154d84 commit c33afad

File tree

6 files changed: +403 −5 lines changed
.gitignore (new file)

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
build_test/
optimized_stable_diffusion/


HF_weights/
outputs/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/


# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon

# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk


pretrained_weights/

backends/model_converter/constants.py

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.
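The constants.py diff is collapsed in this view, but from its use in convert_model.py below, SD_SHAPES maps every expected state_dict key to its tensor shape. A hypothetical sketch of that structure (the two keys shown are illustrative Stable Diffusion v1 names, not copied from the file):

# Hypothetical sketch of SD_SHAPES; the real constants.py enumerates every key.
SD_SHAPES = {
    "model.diffusion_model.input_blocks.0.0.weight": (320, 4, 3, 3),  # illustrative
    "model.diffusion_model.input_blocks.0.0.bias": (320,),            # illustrative
    # ...one entry per checkpoint tensor...
}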
convert_model.py (new file)

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@

from fake_torch import fake_torch_load_zipped
import json
import numpy as np
from constants import SD_SHAPES
import sys

# Usage:
# python convert_model.py "/Users/divamgupta/Downloads/hollie-mengert.ckpt" "/Users/divamgupta/Downloads/hollie-mengert.tdict"

# pyinstaller convert_model.py --onefile --noconfirm --clean
# (build on an Intel machine so that the binary is cross-platform)

checkpoint_filename = sys.argv[1]
out_filename = sys.argv[2]

# TODO: add MD5 checksums

# 24-byte header; the JSON index offsets get patched in at bytes 5 and 14 below.
_HEADER_BYTES = [42, 10, 8, 42] + [0] * 20

s = 24  # tensor data starts immediately after the header

torch_weights = fake_torch_load_zipped(open(checkpoint_filename, "rb"))
keys_info = {}
out_file = open(out_filename, "wb")

out_file.write(bytes(_HEADER_BYTES))

# Write each tensor's raw bytes back-to-back, recording where every key lives.
for k in torch_weights['state_dict']:
    assert k in SD_SHAPES, k
    np_arr = torch_weights['state_dict'][k]
    key_bytes = np_arr.tobytes()
    shape = list(np_arr.shape)
    assert tuple(shape) == SD_SHAPES[k], (k, shape, SD_SHAPES[k])
    dtype = str(np_arr.dtype)
    e = s + len(key_bytes)
    out_file.write(key_bytes)
    keys_info[k] = {"start": s, "end": e, "shape": shape, "dtype": dtype}
    s = e

# Every required key must be present; EMA and schedule buffers are optional.
for k in SD_SHAPES:
    if 'model_ema' in k or 'betas' in k or 'alphas' in k or 'posterior_' in k:
        continue
    assert k in keys_info, k

# Append the JSON index describing all tensors after the tensor data.
json_start = s
info_json = bytes(json.dumps(keys_info), 'ascii')
json_end = s + len(info_json)

out_file.write(info_json)

# Patch the index offsets into the header as fixed 8-byte integers
# ('int64' instead of the platform-dependent 'long' keeps the layout stable).
out_file.seek(5)
out_file.write(np.array(json_start).astype('int64').tobytes())

out_file.seek(14)
out_file.write(np.array(json_end).astype('int64').tobytes())

out_file.close()
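For context, a minimal sketch of how a .tdict written by this script could be read back (this reader is not part of the commit; it just follows the layout above: int64 index offsets at header bytes 5 and 14, tensor bytes, then the JSON index):

import json
import numpy as np

# Sketch of a .tdict reader (not part of this commit), mirroring the writer above.
def read_tdict_tensor(path, key):
    with open(path, "rb") as f:
        header = f.read(24)
        json_start = int(np.frombuffer(header[5:13], dtype='int64')[0])
        json_end = int(np.frombuffer(header[14:22], dtype='int64')[0])
        f.seek(json_start)
        keys_info = json.loads(f.read(json_end - json_start))
        info = keys_info[key]  # {"start", "end", "shape", "dtype"}
        f.seek(info["start"])
        raw = f.read(info["end"] - info["start"])
    return np.frombuffer(raw, dtype=info["dtype"]).reshape(info["shape"])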
fake_torch.py (new file)

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@

import pickle
import numpy as np
import math

def prod(x):
    return math.prod(x)

def my_unpickle(fb0):
    key_prelookup = {}

    class HackTensor:
        # Stands in for torch._utils._rebuild_tensor_v2: instead of building a
        # torch.Tensor, allocate a numpy array and remember where its data lives.
        def __new__(cls, *args):
            ident, storage_type, obj_key, location, obj_size = args[0][0:5]
            assert ident == 'storage'

            assert prod(args[2]) == obj_size
            ret = np.zeros(args[2], dtype=storage_type)
            key_prelookup[obj_key] = (storage_type, obj_size, ret, args[2], args[3])
            return ret

    class HackParameter:
        # Stands in for torch._utils._rebuild_parameter; parameters are dropped.
        def __new__(cls, *args):
            pass

    class Dummy:
        pass

    class MyPickle(pickle.Unpickler):
        def find_class(self, module, name):
            # Map torch storage classes onto numpy dtypes.
            if name == 'FloatStorage':
                return np.float32
            if name == 'LongStorage':
                return np.int64
            if name == 'HalfStorage':
                return np.float16
            if module == "torch._utils":
                if name == "_rebuild_tensor_v2":
                    return HackTensor
                elif name == "_rebuild_parameter":
                    return HackParameter
            else:
                try:
                    return pickle.Unpickler.find_class(self, module, name)
                except Exception:
                    return Dummy

        def persistent_load(self, pid):
            return pid

    return MyPickle(fb0).load(), key_prelookup

def fake_torch_load_zipped(fb0, load_weights=True):
    import zipfile
    with zipfile.ZipFile(fb0, 'r') as myzip:
        with myzip.open('archive/data.pkl') as myfile:
            ret = my_unpickle(myfile)
        if load_weights:
            # Fill each pre-allocated array with the raw bytes stored in the zip.
            for k, v in ret[1].items():
                with myzip.open(f'archive/data/{k}') as myfile:
                    if v[2].dtype == "object":
                        print(f"issue assigning object on {k}")
                        continue
                    np.copyto(v[2], np.frombuffer(myfile.read(), v[2].dtype).reshape(v[3]))
    return ret[0]

def fake_torch_load(b0):
    import io
    import struct

    # convert the raw bytes to a file-like object
    fb0 = io.BytesIO(b0)

    # a zip magic number means this is the newer zipped checkpoint format
    if b0[0:2] == b"\x50\x4b":
        return fake_torch_load_zipped(fb0)

    # skip three junk pickles
    pickle.load(fb0)
    pickle.load(fb0)
    pickle.load(fb0)

    ret, key_prelookup = my_unpickle(fb0)

    # create key_lookup
    key_lookup = pickle.load(fb0)
    key_real = [None] * len(key_lookup)
    for k, v in key_prelookup.items():
        key_real[key_lookup.index(k)] = v

    # read in the actual data
    for storage_type, obj_size, np_array, np_shape, np_strides in key_real:
        ll = struct.unpack("Q", fb0.read(8))[0]
        assert ll == obj_size
        bytes_size = {np.float32: 4, np.int64: 8}[storage_type]
        mydat = fb0.read(ll * bytes_size)
        np.copyto(np_array, np.frombuffer(mydat, storage_type).reshape(np_shape))

        # numpy stores its strides in bytes
        real_strides = tuple([x * bytes_size for x in np_strides])
        np_array.strides = real_strides

    return ret
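A quick usage sketch for fake_torch (not part of the commit; the checkpoint path is a placeholder): it returns the unpickled checkpoint dict with plain numpy arrays in place of torch tensors, which is exactly how convert_model.py consumes it.

# Illustrative only; "some_model.ckpt" is a made-up path.
ckpt = fake_torch_load_zipped(open("some_model.ckpt", "rb"))
for name, arr in list(ckpt['state_dict'].items())[:5]:
    print(name, arr.shape, arr.dtype)  # plain numpy arrays, no torch required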
