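# Gradio demo for KBlueLeaf/DanTagGen: a small LLM that expands a short list of
# Danbooru-style tags into a fuller image-generation prompt.
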
import os
from time import time_ns

import gradio as gr
import torch
import requests as rq
from llama_cpp import Llama, LLAMA_SPLIT_MODE_NONE
from transformers import LlamaForCausalLM, LlamaTokenizer

from kgen.generate import tag_gen
from kgen.metainfo import SPECIAL, TARGET

MODEL_PATH = "KBlueLeaf/DanTagGen"
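

# Build the generation prompt from the UI inputs, stream partial tag expansions
# from the model, then assemble the final formatted prompt string.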
def get_result(
    text_model: LlamaForCausalLM,
    tokenizer: LlamaTokenizer,
    rating: str = "",
    artist: str = "",
    characters: str = "",
    copyrights: str = "",
    target: str = "long",
    special_tags: list[str] = ["1girl"],
    general: str = "",
    aspect_ratio: float = 0.0,
    blacklist: str = "",
    escape_bracket: bool = False,
    temperature: float = 1.35,
):
    start = time_ns()
    print("=" * 50, "\n")
    # Use the LLM to predict a plausible tag expansion.
    # The prompt lets the model itself make the request longer based on what it
    # has learned, which works better for preference sim and the pref-sum
    # contrastive scorer.
    prompt = f"""
rating: {rating or '<|empty|>'}
artist: {artist.strip() or '<|empty|>'}
characters: {characters.strip() or '<|empty|>'}
copyrights: {copyrights.strip() or '<|empty|>'}
aspect ratio: {f"{aspect_ratio:.1f}" if aspect_ratio else '<|empty|>'}
target: {'<|' + target + '|>' if target else '<|long|>'}
general: {", ".join(special_tags)}, {general.strip().strip(",")}<|input_end|>
""".strip()
    artist = artist.strip().strip(",").replace("_", " ")
    characters = characters.strip().strip(",").replace("_", " ")
    copyrights = copyrights.strip().strip(",").replace("_", " ")
    special_tags = [tag.strip().replace("_", " ") for tag in special_tags]
    general = general.strip().strip(",")
    black_list = set(
        [tag.strip().replace("_", " ") for tag in blacklist.strip().split(",")]
    )

    prompt_tags = special_tags + general.strip().strip(",").split(",")
    len_target = TARGET[target]
    llm_gen = ""
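    # tag_gen yields intermediate (generation, extra_tokens) pairs; forward each
    # partial result so the Gradio UI can update while generation is running.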
    for llm_gen, extra_tokens in tag_gen(
        text_model,
        tokenizer,
        prompt,
        prompt_tags,
        len_target,
        black_list,
        temperature=temperature,
        top_p=0.95,
        top_k=100,
        max_new_tokens=256,
        max_retry=5,
    ):
        yield "", llm_gen, f"Total cost time: {(time_ns()-start)/1e9:.2f}s"
    print()
    print("-" * 50)

    general = f"{general.strip().strip(',')}, {','.join(extra_tokens)}"
    tags = general.strip().split(",")
    tags = [tag.strip() for tag in tags if tag.strip()]
    special = special_tags + [tag for tag in tags if tag in SPECIAL]
    tags = [tag for tag in tags if tag not in special]

    final_prompt = ", ".join(special)
    if characters:
        final_prompt += f", \n\n{characters}"
    if copyrights:
        final_prompt += ", "
        if not characters:
            final_prompt += "\n\n"
        final_prompt += copyrights
    if artist:
        final_prompt += f", \n\n{artist}"
    final_prompt += f""", \n\n{', '.join(tags)},
masterpiece, newest, absurdres, {rating}"""

    print(final_prompt)
    print("=" * 50)

    if escape_bracket:
        final_prompt = (
            final_prompt.replace("[", "\\[")
            .replace("]", "\\]")
            .replace("(", "\\(")
            .replace(")", "\\)")
        )

    yield final_prompt, llm_gen, f"Total cost time: {(time_ns()-start)/1e9:.2f}s | Total general tags: {len(special+tags)}"
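

# A minimal usage sketch (hypothetical values; assumes the Llama model and
# LlamaTokenizer loaded in the __main__ block below). get_result is a
# generator, so iterate it to receive streamed updates:
#
#   for final, raw, timing in get_result(
#       text_model, tokenizer, rating="safe", general="solo, smile",
#       aspect_ratio=1.0, target="long"
#   ):
#       print(timing)
#   print(final)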
if __name__ == "__main__":
    tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
    if not os.path.isfile("./model.gguf"):
        data = rq.get(
            "https://huggingface.co/KBlueLeaf/DanTagGen/resolve/main/ggml-model-Q6_K.gguf"
        ).content
        with open("./model.gguf", "wb") as f:
            f.write(data)
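    # Run the downloaded GGUF checkpoint through llama.cpp with a small
    # 384-token context window.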
    text_model = Llama(
        "./model.gguf",
        n_ctx=384,
        verbose=False,
    )
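
    # Adapter between the Gradio components and get_result: converts the
    # width/height sliders into an aspect ratio and forwards everything else.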
    def wrapper(
        rating: str,
        artist: str,
        characters: str,
        copyrights: str,
        target: str,
        special_tags: list[str],
        general: str,
        width: float,
        height: float,
        blacklist: str,
        escape_bracket: bool,
        temperature: float = 1.35,
    ):
        yield from get_result(
            text_model,
            tokenizer,
            rating,
            artist,
            characters,
            copyrights,
            target,
            special_tags,
            general,
            width / height,
            blacklist,
            escape_bracket,
            temperature,
        )
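
    # UI layout: input controls on the left, the formatted prompt and the raw
    # LLM output on the right.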
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Row():
                    with gr.Column(scale=2):
                        rating = gr.Radio(
                            ["safe", "sensitive", "nsfw", "nsfw, explicit"],
                            label="Rating",
                        )
                        special_tags = gr.Dropdown(
                            SPECIAL,
                            value=["1girl"],
                            label="Special tags",
                            multiselect=True,
                        )
                        characters = gr.Textbox(label="Characters")
                        copyrights = gr.Textbox(label="Copyrights (Series)")
                        artist = gr.Textbox(label="Artist")
                        target = gr.Radio(
                            ["very_short", "short", "long", "very_long"],
                            label="Target length",
                        )
                    with gr.Column(scale=2):
                        general = gr.TextArea(label="Input your general tags")
                        black_list = gr.TextArea(
                            label="Tag black list (separated by comma)"
                        )
                        with gr.Row():
                            width = gr.Slider(
                                value=1024,
                                minimum=256,
                                maximum=4096,
                                step=32,
                                label="Width",
                            )
                            height = gr.Slider(
                                value=1024,
                                minimum=256,
                                maximum=4096,
                                step=32,
                                label="Height",
                            )
                        with gr.Row():
                            temperature = gr.Slider(
                                value=1.35,
                                minimum=0.1,
                                maximum=2,
                                step=0.05,
                                label="Temperature",
                            )
                            escape_bracket = gr.Checkbox(
                                value=False,
                                label="Escape bracket",
                            )
                submit = gr.Button("Submit")
            with gr.Column(scale=3):
                formated_result = gr.TextArea(
                    label="Final output", lines=14, show_copy_button=True
                )
                llm_result = gr.TextArea(label="LLM output", lines=10)
                cost_time = gr.Markdown()
        submit.click(
            wrapper,
            # The inputs list must follow wrapper's positional signature.
            inputs=[
                rating,
                artist,
                characters,
                copyrights,
                target,
                special_tags,
                general,
                width,
                height,
                black_list,
                escape_bracket,
                temperature,
            ],
            outputs=[
                formated_result,
                llm_result,
                cost_time,
            ],
            show_progress=True,
        )

    demo.launch()