starlette: Consuming the request body in middleware is problematic

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.middleware.base import BaseHTTPMiddleware


class SampleMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        _ = await request.form()
        return await call_next(request)


app = Starlette()


@app.route('/test', methods=['POST'])
async def test(request):
    _ = await request.form()  # blocked, because middleware already consumed request body
    return PlainTextResponse('Hello, world!')


app.add_middleware(SampleMiddleware)
$ uvicorn test:app --reload
$ curl -d "a=1" http://127.0.0.1:8000/test
# request is blocked

About this issue

  • Original URL
  • State: closed
  • Created 5 years ago
  • Reactions: 17
  • Comments: 35 (8 by maintainers)

Most upvoted comments

Coming from FastAPI issue referenced above.

@tomchristie I don’t understand the issue with consuming the request in a middleware, could you explain this point? In fact, I need (a common requirement where I work) to log every request received by my production server. Is there a better place than a middleware to do it while avoiding duplicated code in every endpoint? (I would like something like this: https://github.com/Rhumbix/django-request-logging)

For now, I found this workaround (but that’s not very pretty):

from starlette.requests import Request
from starlette.types import Message


async def set_body(request: Request, body: bytes):
    async def receive() -> Message:
        return {"type": "http.request", "body": body}

    request._receive = receive


async def get_body(request: Request) -> bytes:
    body = await request.body()
    await set_body(request, body)  # set_body is a coroutine, so it must be awaited
    return body

“but there will always be cases where it can’t work (e.g. if you stream the request data, then it’s just not going to be available any more later down the line)”

I kind of disagree with your example. In fact, stream data is not stored by default, but stream metadata (whether the stream is closed) is; an understandable error is raised if someone tries to stream twice, and that is enough imho. That’s why, if the body is cached, the stream-consumption state has to be cached too.

“Consuming the request body in middleware is problematic”

Indeed. Consuming request data in middleware is problematic; not just in Starlette, but generally, everywhere.

On the whole you should avoid doing so if at all possible.

There’s some work we could do to make it work better, but there will always be cases where it can’t work (e.g. if you stream the request data, then it’s just not going to be available any more later down the line).

There are plenty of use cases like the one @wyfo mentions. In my case I’m using a JWT signature to check data integrity, so I decode the token and then compare the decoded result with the body + query + path params of the request. I don’t know any better way of doing this.

Is there an update on this? It seems like this is something a lot of people need for very legitimate use cases (mainly logging, it seems).

Is there a plan to allow consuming the body in a middleware? I see the body is cached on the request object as _body; is it possible to cache it on the scope so it is accessible from everywhere after it is read?

Any other solution would also be OK, but I do feel this is needed.

@JHBalaji how are you logging the request body to context? Are you not running into the same issue when trying to access the body in the incoming-request middleware?

I had the same problem, but I also needed to consume the response. If anyone has this problem, here is a solution I used:

from typing import Callable, Awaitable, Tuple, Dict, List

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import Scope, Message


class RequestWithBody(Request):
    def __init__(self, scope: Scope, body: bytes) -> None:
        super().__init__(scope, self._receive)
        self._body = body
        self._body_returned = False

    async def _receive(self) -> Message:
        if self._body_returned:
            return {"type": "http.disconnect"}
        else:
            self._body_returned = True
            return {"type": "http.request", "body": self._body, "more_body": False}


class CustomMiddleware(BaseHTTPMiddleware):
    async def dispatch(  # type: ignore
        self, request: Request, call_next: Callable[[Request], Awaitable[StreamingResponse]]
    ) -> Response:
        request_body_bytes = await request.body()

        # use request values

        request_with_body = RequestWithBody(request.scope, request_body_bytes)
        response = await call_next(request_with_body)
        response_content_bytes, response_headers, response_status = await self._get_response_params(response)

        # use response values

        return Response(response_content_bytes, response_status, response_headers)

    async def _get_response_params(self, response: StreamingResponse) -> Tuple[bytes, Dict[str, str], int]:
        response_byte_chunks: List[bytes] = []
        response_status: List[int] = []
        response_headers: List[Dict[str, str]] = []

        async def send(message: Message) -> None:
            if message["type"] == "http.response.start":
                response_status.append(message["status"])
                response_headers.append({k.decode("utf8"): v.decode("utf8") for k, v in message["headers"]})
            else:
                response_byte_chunks.append(message["body"])

        await response.stream_response(send)
        content = b"".join(response_byte_chunks)
        return content, response_headers[0], response_status[0]

I’m posting this in this 4-year-old issue in the hope that someone will see it and it will lead to a better understanding of the problem and ultimately a fix. I have run into this problem many times, and because it doesn’t throw an error but just hangs the program, it always takes me too long to figure out what’s going on.

I think there is some confusion about what the problem is. I don’t believe people are confused about how to wrap receive with a function that consumes the request body, nor do I think people are misguided in needing to consume the full request body in a middleware. Sure, it removes the benefits of a streaming architecture, but that’s an acceptable trade-off in many cases. For example, in the case of HMAC signature verification: The first thing I want to do when my application receives a request is verify the signature. If the signature is bad, I want the request to be rejected before it uses any more resources (database connections, logging, etc.).

If I only wrap the receive function in another function that does signature verification, then the signature won’t be checked until the request has gotten all the way to the endpoint handler where Starlette begins to consume the request body.

The problem comes when you try to consume the request body (i.e., call receive) directly inside of the middleware, like this:

async def __call__(self, scope, receive, send):
    body = b""

    while True:
        message = await receive()
        chunk: bytes = message.get("body", b"")
        body += chunk
        if message.get("more_body", False) is False:
            break

receive is now exhausted, and the next call to it (when Starlette later tries to read the body) hangs because it’s waiting for a message that will never come.
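The hang can be reproduced without Starlette at all: a receive callable backed by a one-shot queue simply never returns once it has been drained. A minimal sketch using plain asyncio (all names here are invented for illustration):

```python
import asyncio


async def make_receive(chunks):
    # Build a receive() callable that yields the given body chunks exactly once.
    queue = asyncio.Queue()
    for i, chunk in enumerate(chunks):
        await queue.put({"type": "http.request", "body": chunk,
                         "more_body": i < len(chunks) - 1})

    async def receive():
        return await queue.get()  # waits forever once the queue is empty

    return receive


async def drain(receive):
    # The same consume loop as in the comment above.
    body = b""
    while True:
        message = await receive()
        body += message.get("body", b"")
        if message.get("more_body", False) is False:
            break
    return body


async def main():
    receive = await make_receive([b"a=", b"1"])
    first = await drain(receive)  # b"a=1"
    try:
        # A second drain never completes; a timeout makes the hang visible.
        await asyncio.wait_for(drain(receive), timeout=0.1)
    except asyncio.TimeoutError:
        return first, "second drain hung"


print(asyncio.run(main()))  # (b'a=1', 'second drain hung')
```

This is exactly what happens downstream of a middleware that consumed the body: the second reader waits on a channel that will never produce another message.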

For me, all that’s needed is a tiny, undocumented property on state that can signal to Starlette that the body has been consumed and it shouldn’t wait for anything else:

async def __call__(self, scope, receive, send):
    body = b""
    # ... consume the body
    scope["state"]["previously_consumed_body"] = body
    await self.app(scope, receive, send)

And in Request.body:

async def body(self) -> bytes:
    if (previously_consumed_body := self.scope["state"].get("previously_consumed_body")):
        return previously_consumed_body
    
    # else, the rest of the function is the same

Sure, it’s hacky. But it’s just one line of code, doesn’t need to be documented (i.e. it comes with the understanding that using it is risky), adds virtually no request overhead, and will save me and probably thousands of other developers from having to refactor their middleware stack or routes, fork this repository, or use another library.

If someone needs a fix now, you can do something like this to trick Starlette (this is similar to other workarounds already mentioned above). Note that this code isn’t production-ready:

async def __call__(self, scope, receive, send):
    body = b""
    # ... consume the body

    async def dummy_receive():
        return {
            "type": "http.request",
            "body": body,
            "more_body": False,
        }

    await self.app(scope, dummy_receive, send)
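Wired together with plain asyncio, the trick above looks like this. This is a self-contained sketch: inner_app and the test harness are invented for illustration and are not Starlette code.

```python
import asyncio


async def inner_app(scope, receive, send):
    # Hypothetical downstream ASGI app that reads the request body itself.
    body = b""
    while True:
        message = await receive()
        body += message.get("body", b"")
        if message.get("more_body", False) is False:
            break
    await send({"type": "http.response.body", "body": b"echo:" + body})


class BodyConsumingMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        # Consume the body up front (e.g. for signature checks or logging).
        body = b""
        while True:
            message = await receive()
            body += message.get("body", b"")
            if message.get("more_body", False) is False:
                break

        async def dummy_receive():
            # Replay the cached body so downstream reads still succeed.
            return {"type": "http.request", "body": body, "more_body": False}

        await self.app(scope, dummy_receive, send)


async def main():
    sent = []

    async def receive_once():
        return {"type": "http.request", "body": b"a=1", "more_body": False}

    async def send(message):
        sent.append(message)

    await BodyConsumingMiddleware(inner_app)({"type": "http"}, receive_once, send)
    return sent[0]["body"]


print(asyncio.run(main()))  # b'echo:a=1'
```

The downstream app sees a complete body even though the middleware already drained the real channel, which is the whole point of the workaround.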

Currently, this is one of the better options: https://fastapi.tiangolo.com/advanced/custom-request-and-route/#accessing-the-request-body-in-an-exception-handler but unfortunately there still isn’t a great way to do this as far as I’m aware.

After reading through this, it seems the issue isn’t with consuming the entire body; starlette actually will pump that async generator and save the result on the request for subsequent reads: https://github.com/encode/starlette/blob/e45c5793611bfec606cd9c5d56887ddc67f77c38/starlette/requests.py#L225

The issue comes when getting form data. That function uses the underlying stream, not the body that may or may not be cached already: https://github.com/encode/starlette/blob/e45c5793611bfec606cd9c5d56887ddc67f77c38/starlette/requests.py#L246

A call to json() properly uses the cached data, but form() doesn’t.

Give this a shot in a controller or middleware:

await request.body()
await request.body()
await request.json()

and I bet that works fine. Toss in an await request.form() and it will give you that “Stream consumed” error.
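The asymmetry described above is easy to model without Starlette: caching the raw bytes makes repeated body() calls safe, but a parser that reads the live stream directly still fails. A toy stand-in (Stream and Req are invented names, not Starlette classes):

```python
class Stream:
    def __init__(self, data: bytes):
        self._data = data
        self._consumed = False

    def read(self) -> bytes:
        if self._consumed:
            raise RuntimeError("Stream consumed")
        self._consumed = True
        return self._data


class Req:
    def __init__(self, data: bytes):
        self._stream = Stream(data)

    def body(self) -> bytes:
        # Like Request.body(): caches the raw bytes, so repeat calls are safe.
        if not hasattr(self, "_body"):
            self._body = self._stream.read()
        return self._body

    def form(self) -> dict:
        # Like form() as described above: reads the live stream, ignoring the cache.
        return dict(p.split("=") for p in self._stream.read().decode().split("&"))


r = Req(b"a=1")
r.body()
r.body()      # fine: the second call hits the cache
try:
    r.form()  # fails: body() already consumed the underlying stream
except RuntimeError as e:
    print(e)  # Stream consumed
```

Caching the parsed form alongside the raw body (or having form() fall back to the cached bytes) would remove the asymmetry, which is essentially what the workaround below does.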

My workaround, which I’m using in conjunction with FastAPI:

from typing import Callable
from io import BytesIO

from fastapi.routing import APIRoute
from starlette.datastructures import FormData
from starlette.formparsers import MultiPartParser, FormParser
from starlette.requests import Request
from starlette.responses import Response

try:
    from multipart.multipart import parse_options_header
except ImportError:  # pragma: nocover
    parse_options_header = None


class AsyncIteratorWrapper:
    """
    Small helper to turn BytesIO into async-able iterator
    """

    def __init__(self, bytes_: bytes):
        super().__init__()
        self._it = BytesIO(bytes_)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            value = next(self._it)
        except StopIteration:
            raise StopAsyncIteration
        return value


class PreProcessedFormRequest(Request):
    async def form(self) -> FormData:
        if not hasattr(self, "_form"):
            assert (
                parse_options_header is not None
            ), "The `python-multipart` library must be installed to use form parsing."
            content_type_header = self.headers.get("Content-Type")
            content_type, options = parse_options_header(content_type_header)

            body_iter = AsyncIteratorWrapper(await self.body())
            if content_type == b"multipart/form-data":
                multipart_parser = MultiPartParser(self.headers, body_iter)
                self._form = await multipart_parser.parse()
            elif content_type == b"application/x-www-form-urlencoded":
                form_parser = FormParser(self.headers, body_iter)
                self._form = await form_parser.parse()
            else:
                self._form = FormData()
        return self._form


class PreProcessedFormRoute(APIRoute):
    def get_route_handler(self) -> Callable:
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            request = PreProcessedFormRequest(request.scope, request.receive)
            return await original_route_handler(request)

        return custom_route_handler

Then installed into FastAPI like:

app = FastAPI()
app.router.route_class = PreProcessedFormRoute

@JivanRoquet yes but the FastAPI documentation does have what you might be looking for.

@tomchristie we’re currently using the above workaround using the code below:

import logging
import time

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import Message


class LoggingMiddleware(BaseHTTPMiddleware):
    def __init__(self, app):
        super().__init__(app)
        self.logger = logging.getLogger()

    async def set_body(self, request):
        # note: this caches a single receive() message, so it only works
        # for bodies that arrive in one chunk
        receive_ = await request.receive()

        async def receive() -> Message:
            return receive_

        request._receive = receive

    async def dispatch(self, request, call_next):
        req_uuid = request.state.correlation_id
        await self.set_body(request)
        body = await request.body()
        self.logger.info(
            "Request",
            extra={
                "uuid": req_uuid,
                "type": "api-request",
                "method": str(request.method).upper(),
                "url": str(request.url),
                "payload": body.decode("utf-8"),
            },
        )
        start_time = time.time()
        response = await call_next(request)
        process_time = (time.time() - start_time) * 1000
        formatted_process_time = "{0:.2f}".format(process_time)
        self.logger.info(
            "Response sent",
            extra={
                "uuid": req_uuid,
                "type": "api-response",
                "url": f"[{str(request.method).upper()}] {str(request.url)}",
                "code": response.status_code,
                "elapsed_time": f"{formatted_process_time}ms",
            },
        )
        return response

This is working great for us. What would you think about adding some kind of mixin class to Starlette, maybe a RequestLoggingMiddleware class, so that it’s easier to log request payloads from the API, even though I agree with you that, in general, this is problematic and should be avoided?

I think we can do it this way:

  1. when you do Request(scope, receive, send), it stores the instance on the scope.
  2. the next time you instantiate a request object, it gets the existing instance from the scope (achievable by implementing __new__). This means we will always have the same Request object.
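A rough sketch of that idea in plain Python (hypothetical: Starlette's real Request does far more, and reusing instances this way has its own pitfalls, e.g. stale state on the cached object):

```python
class CachedRequest:
    """Sketch: reuse one request object per ASGI scope. Invented class, not Starlette's."""

    def __new__(cls, scope, receive=None):
        # If an instance was already created for this scope, return it.
        cached = scope.get("_request_instance")
        if cached is not None:
            return cached
        instance = super().__new__(cls)
        scope["_request_instance"] = instance
        return instance

    def __init__(self, scope, receive=None):
        # __init__ runs even when __new__ returns the cached instance,
        # so guard against resetting its state.
        if getattr(self, "_initialized", False):
            return
        self._initialized = True
        self.scope = scope
        self._receive = receive


scope = {"type": "http"}
r1 = CachedRequest(scope)
r2 = CachedRequest(scope)
assert r1 is r2  # same object both times, so any cached body travels with it
```

With this shape, a body cached as an attribute on the instance would automatically be visible everywhere the request is re-instantiated from the same scope.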