from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import torch
import open_clip

class LatentSafetyFilter:
    """
    An example node.

    Class methods
    -------------
    INPUT_TYPES (dict):
        Tells the main program the input parameters of the node.
    IS_CHANGED:
        Optional method to control when the node is re-executed.

    Attributes
    ----------
    RETURN_TYPES (`tuple`):
        The type of each element in the output tuple.
    RETURN_NAMES (`tuple`):
        Optional: the name of each output in the output tuple.
    FUNCTION (`str`):
        The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute().
    OUTPUT_NODE ([`bool`]):
        If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example.
        The backend iterates over these output nodes and tries to execute all their parents if their parent graph is properly connected.
        Assumed to be False if not present.
    CATEGORY (`str`):
        The category under which the node should appear in the UI.
    execute(s) -> tuple || None:
        The entry-point method. The name of this method must be the same as the value of the `FUNCTION` property.
        For example, if `FUNCTION = "execute"` then this method's name must be `execute`; if `FUNCTION = "foo"` then it must be `foo`.
    """
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        """
        Return a dictionary containing the config for all input fields.
        Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
        Input types "INT", "STRING" or "FLOAT" are special values for fields on the node.
        The type can be a list for selection.

        Returns: `dict`:
            - Key input_fields_group (`string`): Can be either required, hidden or optional (a commented sketch of the optional/hidden groups follows the return below). A node class must have the property `required`.
            - Value input_fields (`dict`): Contains the input field config:
                * Key field_name (`string`): Name of an entry-point method's argument
                * Value field_config (`tuple`):
                    + The first value is a string indicating the type of the field, or a list for selection.
                    + The second value is a config dict for the types "INT", "STRING" or "FLOAT".
        """
        return {
            "required": {
                "samples": ("LATENT", ),
                "safety_filter": ("STRING", {
                    "multiline": False,  # True if you want the field to look like the one on the CLIPTextEncode node
                    "default": "nsfw"
                }),
                "threshold": ("FLOAT", {
                    "default": 0.2,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.01,
                    "round": 0.001,  # The precision to round to; set to the step value by default. Can be set to False to disable rounding.
                    "display": "number"}),
                "int_field": ("INT", {
                    "default": 0,
                    "min": 0,  # Minimum value
                    "max": 4096,  # Maximum value
                    "step": 64,  # Slider's step
                    "display": "number"  # Cosmetic only: display as "number" or "slider"
                }),
                "print_to_screen": (["enable", "disable"],),
            },
        }
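
        # A minimal sketch of how "optional" and "hidden" input groups could sit alongside
        # "required" (the extra field names below are illustrative only, not inputs this node defines):
        #
        # return {
        #     "required": {"samples": ("LATENT",)},
        #     "optional": {"mask": ("MASK",)},
        #     "hidden": {"unique_id": "UNIQUE_ID"},
        # }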

    RETURN_TYPES = ("LATENT",)
    #RETURN_NAMES = ("image_output_name",)

    FUNCTION = "test"

    #OUTPUT_NODE = False

    CATEGORY = "Safety"

    def get_model_info(self, model_ID, device):
        # Helper (not called by the filter itself): load a Hugging Face CLIP model together with its processor and tokenizer.
        model = CLIPModel.from_pretrained(model_ID).to(device)
        processor = CLIPProcessor.from_pretrained(model_ID)
        tokenizer = CLIPTokenizer.from_pretrained(model_ID)
        return model, processor, tokenizer

    def test(self, samples, safety_filter, int_field, threshold, print_to_screen):
        # Available latent-CLIP checkpoints (the paths are specific to the local setup).
        models = {'B-8': {'model_name': 'Latent-ViT-B-8-512',
                          'pretrained': '/dlabdata1/wendler/models/latent-clip-b-8.pt'},
                  'B-4-plus': {'model_name': 'Latent-ViT-B-4-512-plus',
                               'pretrained': '/dlabdata1/wendler/models/latent-clip-b-4-plus.pt'}}
        size = 'B-4-plus'
        model_name = models[size]['model_name']
        pretrained = models[size]['pretrained']
        model_latent, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
        tokenizer_latent = open_clip.get_tokenizer(model_name)

        # Embed the latents and the two text prompts ("an image of <concept>" vs. "an image of no <concept>").
        image_features = model_latent.encode_image(samples["samples"])
        text_features = model_latent.encode_text(tokenizer_latent([f"an image of {safety_filter}", f"an image of no {safety_filter}"]))

        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # Softmax over the two prompts; column 0 is the probability that a latent matches the filtered concept.
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        print(text_probs)

        # Zero out every latent whose probability of matching the filter concept exceeds the threshold.
        for i, sample in enumerate(samples["samples"]):
            if text_probs[i][0].item() > threshold:
                samples["samples"][i].zero_()
                print("Sample", i, "processed: Set to zero")
            else:
                print("Sample", i, "processed: Not set to zero")

            print("Probability:", text_probs[i][0].item())
            print("Threshold:", threshold)
            # print("Safety Filter:", safety_filter)

        return (samples,)

    """
    The node will always be re-executed if any of the inputs change, but
    this method can be used to force the node to execute again even when the inputs don't change.
    You can make this node return a number or a string. This value will be compared to the one returned the last time the node was
    executed; if it is different, the node will be executed again.
    This method is used in the core repo for the LoadImage node, which returns the image hash as a string; if the image hash
    changes between executions, the LoadImage node is executed again.
    """
    #@classmethod
    #def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
    #    return ""
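    #
    # A minimal sketch of a hash-based IS_CHANGED, assuming the node took a file-path input named
    # `image` (hypothetical; this node defines no such input):
    #
    # @classmethod
    # def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
    #     import hashlib
    #     with open(image, "rb") as f:
    #         return hashlib.sha256(f.read()).hexdigest()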

# Set the web directory; any .js file in that directory will be loaded by the frontend as a frontend extension.
# WEB_DIRECTORY = "./somejs"

# A dictionary that contains all the nodes you want to export together with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
    "LatentSafetyFilter": LatentSafetyFilter
}

# A dictionary that contains the friendly/human-readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "LatentSafetyFilter": "Latent Safety Filter"
}
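
# A minimal standalone sketch of how the node's entry point could be exercised outside ComfyUI,
# assuming the latent-CLIP checkpoints referenced above are available locally (they are setup-specific)
# and that the latent shape below matches what the latent-CLIP model expects:
if __name__ == "__main__":
    node = LatentSafetyFilter()
    dummy = {"samples": torch.randn(1, 4, 64, 64)}  # one latent in ComfyUI's {"samples": tensor} format
    (filtered,) = node.test(dummy, safety_filter="nsfw", int_field=0, threshold=0.2, print_to_screen="enable")
    print("Latent magnitude after filtering:", filtered["samples"].abs().sum().item())  # 0.0 if the latent was zeroed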