gradio: if user exits browser or tab, gradio not cleaning up process/threads

Describe the bug

https://huggingface.co/docs/transformers/internal/generation_utils#transformers.TextIteratorStreamer

When using this, if the user closes the tab or the browser, generation continues in the background indefinitely.

Is there an existing issue for this?

  • I have searched the existing issues

Reproduction

import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    streamer = TextIteratorStreamer(tok)


    def respond(message, chat_history):
        from threading import Thread

        # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        inputs = tok([message], return_tensors="pt")
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1000)
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        bot_message = ""
        chat_history.append([message, bot_message])
        for new_text in streamer:
            bot_message += new_text
            chat_history[-1][1] = bot_message
            yield chat_history
        return

    msg.submit(respond, [msg, chatbot], chatbot, queue=True)
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue(concurrency_count=1)
demo.launch()

Enter “The sky is” and close the tab soon after. Generation continues. For larger models this is a problem: gradio sees the user as gone and frees the queue slot, but the generation threads keep running and will overlap with new requests.

Also, note that adding a raise StopIteration() has no effect on model.generate(); it only terminates the respond generator, so GPU usage continues in the background. i.e.

import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    streamer = TextIteratorStreamer(tok)


    def respond(message, chat_history):
        from threading import Thread

        # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        inputs = tok([message], return_tensors="pt")
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1000)
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        bot_message = ""
        chat_history.append([message, bot_message])
        for new_text in streamer:
            bot_message += new_text
            chat_history[-1][1] = bot_message
            if len(bot_message) > 50:
                raise StopIteration()
            yield chat_history
        return

    msg.submit(respond, [msg, chatbot], chatbot, queue=True)
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue(concurrency_count=1)
demo.launch()

This terminates the output to the gradio chatbot but continues to use the GPU in the other thread. Joining the thread doesn't help; it just blocks until generation finishes. Techniques for raising an exception inside a thread exist, but they don't solve the first issue and are not standard practice.
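Until gradio offers a disconnect hook, the only reliable way to stop the worker is cooperative cancellation: the generation loop checks a flag, and the consumer sets that flag when it is torn down. Below is a minimal stdlib sketch of the pattern; fake_generate is a stand-in for model.generate, and with transformers the same check could live in a custom StoppingCriteria passed via the stopping_criteria argument.

```python
import queue
import threading

def fake_generate(stop_event, out_q, n_tokens=1000):
    # Stand-in for model.generate(): emits tokens until done or cancelled.
    for i in range(n_tokens):
        if stop_event.is_set():
            break
        out_q.put(f"tok{i} ")
    out_q.put(None)  # sentinel: generation finished or was cancelled

def respond():
    stop_event = threading.Event()
    out_q = queue.Queue()
    thread = threading.Thread(target=fake_generate, args=(stop_event, out_q))
    thread.start()
    text = ""
    try:
        while (tok := out_q.get()) is not None:
            text += tok
            yield text
    finally:
        # Also runs on GeneratorExit, i.e. when the framework closes the
        # generator because the client went away: signal the worker and
        # reclaim the thread instead of leaving it running.
        stop_event.set()
        thread.join()

gen = respond()
first = next(gen)  # stream a bit...
gen.close()        # ...then simulate the client disconnecting mid-stream
```

The key point is that cancellation is checked inside the generation loop itself, which is why thread-killing tricks and raise StopIteration() in the consumer cannot achieve the same effect.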

Screenshot


Logs

None required.

System Info

Latest gradio, Chrome browser.

Severity

blocking all usage of gradio

About this issue

  • Original URL
  • State: open
  • Created a year ago
  • Comments: 18 (13 by maintainers)

Commits related to this issue

Most upvoted comments

What's really needed is a per-session disconnect callback, like streamlit provides. That way one can properly handle things like CUDA models, moving the model off the GPU and clearing the torch cache, but only for that specific session's state.

I confirmed, at least manually, that this works with the approach above. On disconnect, I found my model state and did:

        app.state_holder.session_data[session_hash][1]['model'].cpu()
        import torch
        import gc
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            gc.collect()
before popping the item, because CUDA memory remains cached unless the model is moved to the CPU first and the cache is cleared afterwards.

i.e.

        def clear_torch_cache():
            import torch
            import gc
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
                gc.collect()

        from starlette.websockets import WebSocket, WebSocketDisconnect

        @app.websocket("/ws")
        async def websocket_endpoint(websocket: WebSocket):
            a = app.auth
            await websocket.accept()
            try:
                while True:
                    # You can also process incoming messages here
                    data = await websocket.receive_text()
                    print(f"Message from client: {data}")
            except WebSocketDisconnect:
                hashes_cleared = []
                # Iterate over a copy: popping entries while iterating the
                # live dict raises RuntimeError.
                for session_hash in list(app.state_holder.session_data):
                    hashes_cleared.append(session_hash)
                    states = app.state_holder.session_data[session_hash]
                    if isinstance(states[1], dict) and 'model' in states[1] and hasattr(states[1]['model'], 'cpu'):
                        states[1]['model'].cpu()
                        clear_torch_cache()
                    app.state_holder.session_data.pop(session_hash)
                print("Client disconnected: %s" % hashes_cleared, flush=True)

I basically need to know in gradio when user closes tab. How can I do that?

+1, I really need this too. I have a gradio interface that lets users upload files to storage, but the files remain even after the users leave the interface, and the storage fills up. We need something like a change event for gr.State() that triggers when the user leaves the page. It would help a lot of users, not only me. Am I right?
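Until gradio exposes a leave-page event, one workaround for the upload case is a periodic sweep that deletes files not modified within a TTL. A minimal stdlib sketch follows; the flat directory layout and the 1-hour default TTL are assumptions for illustration, not gradio behavior.

```python
import os
import time

def sweep_stale_uploads(root, max_age_seconds=3600, now=None):
    # Delete files under root whose mtime is older than the TTL and
    # return their names; untouched files and subdirectories are kept.
    now = time.time() if now is None else now
    removed = []
    for name in os.listdir(root):
        path = os.path.join(root, name)
        if os.path.isfile(path) and now - os.path.getmtime(path) > max_age_seconds:
            os.remove(path)
            removed.append(name)
    return removed
```

In practice this would run from a background thread or a cron job; it cleans up after vanished users eventually, rather than immediately on disconnect as a proper event would.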