
Commit 2f89716

Add safety filter nodes
1 parent c967392 · commit 2f89716

File tree

7 files changed: +1952 -2 lines changed


.gitignore

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@ __pycache__/
 !/input/example.png
 /models/
 /temp/
-/custom_nodes/
 !custom_nodes/example_node.py.example
 extra_model_paths.yaml
 /.vs

Notebok_workflow_latent_safety.ipynb

Lines changed: 524 additions & 0 deletions
Large diffs are not rendered by default.

Notebok_workflow_safety.ipynb

Lines changed: 391 additions & 0 deletions
Large diffs are not rendered by default.

comfy/clip_model.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def __init__(self, config_dict, dtype, device, operations):
         self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
         embed_dim = config_dict["hidden_size"]
         self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
-        self.text_projection.weight.copy_(torch.eye(embed_dim))
+        self.text_projection.weight.data.copy_(torch.eye(embed_dim))
         self.dtype = dtype

     def get_input_embeddings(self):
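
A likely motivation for this one-line change: calling copy_ directly on an nn.Parameter that requires grad is an in-place operation on a leaf tensor, which PyTorch rejects with a RuntimeError outside a no_grad block, whereas going through .data bypasses autograd. A minimal standalone sketch of the difference, with a plain torch.nn.Linear standing in for this repo's operations.Linear:

import torch
import torch.nn as nn

proj = nn.Linear(4, 4, bias=False)

# In-place copy on a leaf Parameter that requires grad raises a RuntimeError.
try:
    proj.weight.copy_(torch.eye(4))
except RuntimeError as err:
    print("copy_ failed:", err)

# Going through .data skips autograd, so the in-place init succeeds.
# (Wrapping the copy_ in `with torch.no_grad():` would work as well.)
proj.weight.data.copy_(torch.eye(4))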
custom_nodes/latent_safety_filter.py

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import torch
import open_clip

class LatentSafetyFilter:
    """
    An example node

    Class methods
    -------------
    INPUT_TYPES (dict):
        Tells the main program the input parameters of the node.
    IS_CHANGED:
        Optional method to control when the node is re-executed.

    Attributes
    ----------
    RETURN_TYPES (`tuple`):
        The type of each element in the output tuple.
    RETURN_NAMES (`tuple`):
        Optional: The name of each output in the output tuple.
    FUNCTION (`str`):
        The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute()
    OUTPUT_NODE ([`bool`]):
        If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example.
        The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected.
        Assumed to be False if not present.
    CATEGORY (`str`):
        The category the node should appear in the UI.
    execute(s) -> tuple || None:
        The entry-point method. The name of this method must be the same as the value of the property `FUNCTION`.
        For example, if `FUNCTION = "execute"` then this method's name must be `execute`; if `FUNCTION = "foo"` then it must be `foo`.
    """
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        """
        Return a dictionary which contains the config for all input fields.
        Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
        Input types "INT", "STRING" or "FLOAT" are special values for fields on the node.
        The type can be a list for selection.

        Returns: `dict`:
            - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have the property `required`
            - Value input_fields (`dict`): Contains the input fields' config:
                * Key field_name (`string`): Name of an entry-point method's argument
                * Value field_config (`tuple`):
                    + The first value is a string indicating the type of the field, or a list for selection.
                    + The second value is a config for types "INT", "STRING" or "FLOAT".
        """
        return {
            "required": {
                "samples": ("LATENT",),
                "safety_filter": ("STRING", {
                    "multiline": False,  # True if you want the field to look like the one on the ClipTextEncode node
                    "default": "nsfw"
                }),
                "threshold": ("FLOAT", {
                    "default": 0.2,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.01,
                    "round": 0.001,  # The precision to round to; set to the step value by default. Can be set to False to disable rounding.
                    "display": "number"}),
                "int_field": ("INT", {
                    "default": 0,
                    "min": 0,  # Minimum value
                    "max": 4096,  # Maximum value
                    "step": 64,  # Slider's step
                    "display": "number"  # Cosmetic only: display as "number" or "slider"
                }),
                "print_to_screen": (["enable", "disable"],),
            },
        }

    RETURN_TYPES = ("LATENT",)
    #RETURN_NAMES = ("image_output_name",)

    FUNCTION = "test"

    #OUTPUT_NODE = False

    CATEGORY = "Safety"

    def get_model_info(self, model_ID, device):
        model = CLIPModel.from_pretrained(model_ID).to(device)
        processor = CLIPProcessor.from_pretrained(model_ID)
        tokenizer = CLIPTokenizer.from_pretrained(model_ID)
        return model, processor, tokenizer

    def test(self, samples, safety_filter, int_field, threshold, print_to_screen):
        models = {'B-8': {'model_name': 'Latent-ViT-B-8-512',
                          'pretrained': '/dlabdata1/wendler/models/latent-clip-b-8.pt'},
                  'B-4-plus': {'model_name': 'Latent-ViT-B-4-512-plus',
                               'pretrained': '/dlabdata1/wendler/models/latent-clip-b-4-plus.pt'}}
        size = 'B-4-plus'
        model_name = models[size]['model_name']
        pretrained = models[size]['pretrained']
        model_latent, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
        tokenizer_latent = open_clip.get_tokenizer(model_name)

        # Score each latent against "an image of <filter>" vs. "an image of no <filter>".
        image_features = model_latent.encode_image(samples["samples"])
        text_features = model_latent.encode_text(
            tokenizer_latent([f"an image of {safety_filter}", f"an image of no {safety_filter}"]))

        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        print(text_probs)

        # Zero out every latent whose probability of matching the filter text exceeds the threshold.
        for i, sample in enumerate(samples["samples"]):
            if text_probs[i][0].item() > threshold:
                samples["samples"][i].zero_()
                print("Sample", i, "processed: Set to zero")
            else:
                print("Sample", i, "processed: Not set to zero")
            print("Probability:", text_probs[i][0].item())
            print("Threshold:", threshold)

        return (samples,)

"""
The node will always be re-executed if any of the inputs change, but
this method can be used to force the node to execute again even when the inputs don't change.
You can make this node return a number or a string. This value will be compared to the one returned the last time the node was
executed; if it is different, the node will be executed again.
This method is used in the core repo for the LoadImage node, where they return the image hash as a string; if the image hash
changes between executions, the LoadImage node is executed again.
"""
#@classmethod
#def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
#    return ""

# Set the web directory; any .js file in that directory will be loaded by the frontend as a frontend extension
# WEB_DIRECTORY = "./somejs"

# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
    "LatentSafetyFilter": LatentSafetyFilter
}

# A dictionary that contains the friendly/human-readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "LatentSafetyFilter": "Latent Safety Filter"
}
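
For orientation, a minimal sketch of driving this node outside a ComfyUI graph. The import path and latent shape are assumptions, and it only runs if the open_clip checkpoints referenced above exist locally; the node expects ComfyUI's {"samples": tensor} latent format:

import torch
from latent_safety_filter import LatentSafetyFilter  # hypothetical import path

# Fake batch of two SD-style latents (the shape is a placeholder).
dummy_latents = {"samples": torch.randn(2, 4, 64, 64)}

node = LatentSafetyFilter()
(filtered,) = node.test(
    samples=dummy_latents,
    safety_filter="nsfw",
    int_field=0,
    threshold=0.2,
    print_to_screen="enable",
)
# Latents scored above the threshold come back zeroed; the rest pass through unchanged.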

custom_nodes/safety_filter.py

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import torch

class SafetyFilter:
    """
    An example node

    Class methods
    -------------
    INPUT_TYPES (dict):
        Tells the main program the input parameters of the node.
    IS_CHANGED:
        Optional method to control when the node is re-executed.

    Attributes
    ----------
    RETURN_TYPES (`tuple`):
        The type of each element in the output tuple.
    RETURN_NAMES (`tuple`):
        Optional: The name of each output in the output tuple.
    FUNCTION (`str`):
        The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute()
    OUTPUT_NODE ([`bool`]):
        If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example.
        The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected.
        Assumed to be False if not present.
    CATEGORY (`str`):
        The category the node should appear in the UI.
    execute(s) -> tuple || None:
        The entry-point method. The name of this method must be the same as the value of the property `FUNCTION`.
        For example, if `FUNCTION = "execute"` then this method's name must be `execute`; if `FUNCTION = "foo"` then it must be `foo`.
    """
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        """
        Return a dictionary which contains the config for all input fields.
        Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
        Input types "INT", "STRING" or "FLOAT" are special values for fields on the node.
        The type can be a list for selection.

        Returns: `dict`:
            - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have the property `required`
            - Value input_fields (`dict`): Contains the input fields' config:
                * Key field_name (`string`): Name of an entry-point method's argument
                * Value field_config (`tuple`):
                    + The first value is a string indicating the type of the field, or a list for selection.
                    + The second value is a config for types "INT", "STRING" or "FLOAT".
        """
        return {
            "required": {
                "image": ("IMAGE",),
                "safety_filter": ("STRING", {
                    "multiline": False,  # True if you want the field to look like the one on the ClipTextEncode node
                    "default": "nsfw"
                }),
                "threshold": ("FLOAT", {
                    "default": 0.2,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.01,
                    "round": 0.001,  # The precision to round to; set to the step value by default. Can be set to False to disable rounding.
                    "display": "number"}),
                "int_field": ("INT", {
                    "default": 0,  # Minimum value
                    "min": 0,  # Minimum value
                    "max": 4096,  # Maximum value
                    "step": 64,  # Slider's step
                    "display": "number"  # Cosmetic only: display as "number" or "slider"
                }),
                "print_to_screen": (["enable", "disable"],),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    #RETURN_NAMES = ("image_output_name",)

    FUNCTION = "test"

    #OUTPUT_NODE = False

    CATEGORY = "Safety"

    def get_model_info(self, model_ID, device):
        model = CLIPModel.from_pretrained(model_ID).to(device)
        processor = CLIPProcessor.from_pretrained(model_ID)
        tokenizer = CLIPTokenizer.from_pretrained(model_ID)
        return model, processor, tokenizer

    def test(self, image, safety_filter, int_field, threshold, print_to_screen):
        device = "cuda" if torch.cuda.is_available() else "cpu"

        model_IDs = ["openai/clip-vit-base-patch32", "openai/clip-vit-large-patch14"]
        model_ID = model_IDs[1]
        model_clip, processor, tokenizer = self.get_model_info(model_ID, device)

        #url = "http://images.cocodataset.org/val2017/000000039769.jpg"

        inputs = processor(text=safety_filter, images=image, return_tensors="pt", padding=True).to(device)
        outputs = model_clip(**inputs)

        logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities

        if print_to_screen == "enable":
            print(f"""Your input contains:
                safety_filter aka input text: {safety_filter}
                int_field: {int_field}
                threshold: {threshold}
            """)
            print(probs)
            print(logits_per_image)

        # Zero out every image whose (rescaled) similarity to the filter text exceeds the threshold.
        for i, image_logits in enumerate(logits_per_image):
            if image_logits.item() / 100.0 > threshold:
                image[i] = 0.0 * image[i]

        return (image,)

"""
The node will always be re-executed if any of the inputs change, but
this method can be used to force the node to execute again even when the inputs don't change.
You can make this node return a number or a string. This value will be compared to the one returned the last time the node was
executed; if it is different, the node will be executed again.
This method is used in the core repo for the LoadImage node, where they return the image hash as a string; if the image hash
changes between executions, the LoadImage node is executed again.
"""
#@classmethod
#def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
#    return ""

# Set the web directory; any .js file in that directory will be loaded by the frontend as a frontend extension
# WEB_DIRECTORY = "./somejs"

# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
    "SafetyFilter": SafetyFilter
}

# A dictionary that contains the friendly/human-readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "SafetyFilter": "Safety Filter"
}
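
Similarly, a minimal sketch of exercising the image-space node directly, assuming ComfyUI's (batch, height, width, channel) float IMAGE convention; the tensor below is a placeholder, and the CLIP weights are fetched from the Hugging Face Hub on first use:

import torch
from safety_filter import SafetyFilter  # hypothetical import path

# Fake batch of two RGB images in ComfyUI's (B, H, W, C) float format, values in [0, 1].
dummy_images = torch.rand(2, 224, 224, 3)

node = SafetyFilter()
(filtered,) = node.test(
    image=dummy_images,
    safety_filter="nsfw",
    int_field=0,
    threshold=0.2,
    print_to_screen="disable",
)
# Any image whose CLIP similarity to the filter text exceeds the threshold is zeroed.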
