
Flask RESTful API request parameter validation with Marshmallow schemas

Create separate schemas for path, query and body request parameters and validate them with a single function.

30 March 2021 Updated 22 April 2021
In API, Flask

When you build a RESTful API, the first thing you do is define the status codes and error responses. RFC 7807, 'Problem Details for HTTP APIs', specifies a standard format for error responses, with members such as type, title, status, detail and instance. If you have not looked into this yet, I suggest you do. Of course, you often want to include more details about what went wrong. APIs are for developers, and we want to make it easy for them to understand why a call failed.

When you build an API with Flask, it is hard to avoid a serialization/deserialization package like Marshmallow. Combined with the APISpec package, it also makes documenting your API easy.

When one of your API methods is called, the first thing it should do is validate the request parameters. Below I show how I did this with Marshmallow.

Input validation with Marshmallow

Writing an API is different from writing a web application. For a Flask web application with forms, we use the WTForms package to retrieve and validate the request parameters. For an API, we use schemas. Schemas describe how to communicate with an API and allow us to document our API with Swagger, for example.

The Marshmallow package not only lets us deserialize input parameters but also contains a large set of validators.

An example model class: City

In this article I will be using the (SQLAlchemy) model class City.

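from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class City(Base):
    __tablename__ = 'city'

    id = Column(Integer, primary_key=True)
    name = Column(String(100), server_default='')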

We want the following operations:

  • Get a list of cities, with pagination and filter
  • Get a city by id
  • Create a city
  • Update a city by id
  • Delete a city by id

Then we have the following requests:

GET      /cities?page=1&per_page=10&search=an
GET      /cities/4
POST     /cities, city name in request body
PUT      /cities/5, city name in request body
DELETE   /cities/7

And our view functions look like this:

@blueprint_cities.route('', methods=['GET'])
def cities_list():
    ...

@blueprint_cities.route('/<int:city_id>', methods=['GET'])
def cities_get(city_id):
    ...

@blueprint_cities.route('', methods=['POST'])
def cities_create():
    ...

@blueprint_cities.route('/<int:city_id>', methods=['PUT'])
def cities_update(city_id):
    ...

@blueprint_cities.route('/<int:city_id>', methods=['DELETE'])
def cities_delete(city_id):
    ...

Flask and request parameters

There are three types of request parameters:

  • Path parameters: city_id
  • Query parameters: page, per_page, search
  • Body parameters: name

In Flask we can access these parameters using the request object:

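  • request.view_args for path parameters
  • request.args for query parameters
  • request.get_json() for a JSON body (or request.form for form-encoded data)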
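To see where Flask puts each kind of parameter, here is a minimal sketch that uses Flask's test_request_context(), so no server is needed. The route and the inspect_request() helper are hypothetical, for illustration only. Note that query values always arrive as strings, which is one reason to deserialize them with a schema.

```python
from flask import Flask, request

app = Flask(__name__)


@app.route('/cities/<int:city_id>', methods=['PUT'])
def cities_update(city_id):
    return ''


def inspect_request():
    """Collect the three kinds of request parameters from the active request."""
    return {
        'path': request.view_args,        # already converted by the <int:...> converter
        'query': request.args.to_dict(),  # query values are always strings
        'body': request.get_json(),       # parsed JSON body
    }


# Build a fake PUT request; pushing the context also matches the URL rule.
with app.test_request_context('/cities/5?page=2', method='PUT', json={'name': 'Hamburg'}):
    params = inspect_request()

print(params)
# {'path': {'city_id': 5}, 'query': {'page': '2'}, 'body': {'name': 'Hamburg'}}
```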

Translation to Marshmallow schemas

We start by creating base schema classes for the three types of request parameters:

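from marshmallow import Schema, fields, validate

# Empty base classes: they only mark which part of the request a schema validates.
class RequestPathParamsSchema(Schema):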
    pass

class RequestQueryParamsSchema(Schema):
    pass

class RequestBodyParamsSchema(Schema):
    pass

Then we create the schemas for the City requests, derived from these base classes:

class CitiesRequestQueryParamsSchema(PaginationQueryParamsSchema, NameFilterQueryParamsSchema):
    pass
    
class CitiesRequestPathParamsSchema(RequestPathParamsSchema):
    city_id = fields.Int(
        metadata={'description': 'The id of the city.'},
        validate=validate.Range(min=1, max=9999),
        required=True)

class CitiesCreateRequestBodyParamsSchema(RequestBodyParamsSchema):
    name = fields.Str(
        metadata={'description': 'The name of the city'},
        validate=validate.Length(min=2, max=40),
        required=True)

class CitiesUpdateRequestBodyParamsSchema(RequestBodyParamsSchema):
    name = fields.Str(
        metadata={'description': 'The name of the city'},
        validate=validate.Length(min=2, max=40),
        required=True)

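This looks very much like defining a form with WTForms. CitiesRequestQueryParamsSchema above combines two reusable query schemas, PaginationQueryParamsSchema and NameFilterQueryParamsSchema; in the actual module they must be defined before it. Here they are: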

class PaginationQueryParamsSchema(RequestQueryParamsSchema):
    # load_default= replaces the deprecated missing= argument (marshmallow 3.13+)
    page = fields.Int(
        load_default=PAGINATION_PAGE_VALUE_DEFAULT,
        metadata={'description': 'Pagination page number, first page is 1.'},
        validate=validate.Range(min=PAGINATION_PAGE_VALUE_MIN, max=PAGINATION_PAGE_VALUE_MAX),
        required=False)
    per_page = fields.Int(
        load_default=PAGINATION_PER_PAGE_VALUE_DEFAULT,
        metadata={'description': 'Pagination items per page.'},
        validate=validate.Range(min=PAGINATION_PER_PAGE_VALUE_MIN, max=PAGINATION_PER_PAGE_VALUE_MAX),
        required=False)

class NameFilterQueryParamsSchema(RequestQueryParamsSchema):
    # query parameter 'search', as in GET /cities?search=an
    search = fields.Str(
        metadata={'description': 'The (part of the) name to search for'},
        validate=validate.Length(min=2, max=40),
        required=False)

For deserialization and validation, Marshmallow provides the method load(). It returns a dictionary that maps field names to deserialized values. If validation fails, it raises a ValidationError.

For example, to check the URL path parameter city_id in the GET, PUT and DELETE methods, we call load() as follows:

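from flask import request
from marshmallow import ValidationError

try:
    result = CitiesRequestPathParamsSchema().load(request.view_args)
except ValidationError as err:
    ...  # handle the error, e.g. return a 400 response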

When no errors are found, the result looks like:

{'city_id': 5}

We can do the same for the URL query parameters and the request body parameters.

A single function checking all schemas

With the above, a view function still has to call load() for every schema it uses. In our example only the PUT method uses two schemas, but in a more realistic API you can expect more. I wanted a single call that handles all of them, to avoid repetition and reduce code.

Because our City schemas inherit from the base classes RequestPathParamsSchema, RequestQueryParamsSchema and RequestBodyParamsSchema, we can use the base class to decide which request data to pass to load(). Python's isinstance() function does this: it returns True not only for instances of the class itself but also for instances of its subclasses.
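The dispatch can be seen in isolation with plain Python classes; Marshmallow is left out so that only the class hierarchy matters. BASE_CLASSES and request_parts() are hypothetical helpers. Like the separate if statements in request_schemas_load(), they report every base class a schema inherits from, even through several levels of (multiple) inheritance.

```python
class RequestPathParamsSchema: ...
class RequestQueryParamsSchema: ...
class RequestBodyParamsSchema: ...


class PaginationQueryParamsSchema(RequestQueryParamsSchema): ...
class NameFilterQueryParamsSchema(RequestQueryParamsSchema): ...
class CitiesRequestQueryParamsSchema(PaginationQueryParamsSchema, NameFilterQueryParamsSchema): ...
class CitiesRequestPathParamsSchema(RequestPathParamsSchema): ...


# Maps each marker base class to the part of the request it validates.
BASE_CLASSES = [
    ('path', RequestPathParamsSchema),
    ('query', RequestQueryParamsSchema),
    ('body', RequestBodyParamsSchema),
]


def request_parts(schema):
    """Return every request part whose base class the schema inherits from."""
    return [part for part, base in BASE_CLASSES if isinstance(schema, base)]


print(request_parts(CitiesRequestQueryParamsSchema()))  # ['query']
print(request_parts(CitiesRequestPathParamsSchema()))   # ['path']
```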

I created a helper class APIUtils with a method request_schemas_load() that takes one schema or a list of schemas and validates and loads them all.

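from flask import request
from marshmallow import ValidationError

# app_db, APIError and api_spec (the module holding the schemas) are part of
# the author's project and are not shown here.

class APIUtils:
    ...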

    @classmethod
    def schema_check(cls, schema=None, json_data=None, title=None):
        # load and validate
        try:
            return schema.load(data=json_data, partial=True)

        except ValidationError as err:
            raise APIError(
                status_code=400,
                title=title,
                messages=err.messages,
                data=err.data,
                valid_data=err.valid_data,
            )

    @classmethod
    def request_schemas_load(cls, schemas):
        if not isinstance(schemas, list):
            schemas = [schemas]

        result_path, result_query, result_body = {}, {}, {}
        for schema in schemas:
            if isinstance(schema, api_spec.RequestPathParamsSchema):
                # path params
                result_path.update(cls.schema_check(
                    schema=schema,
                    json_data=request.view_args,
                    title='One or more request url path parameters did not validate'))

            if isinstance(schema, api_spec.RequestQueryParamsSchema):
                # query params
                result_query.update(cls.schema_check(
                    schema=schema, 
                    json_data=request.args,
                    title='One or more request url query parameters did not validate'))

            if isinstance(schema, api_spec.RequestBodyParamsSchema):
                # body params
                result_body.update(cls.schema_check(
                    schema=schema, 
                    json_data=request.get_json(),
                    title='One or more request body parameters did not validate'))

        return {
            'path': result_path,
            'query': result_query,
            'body': result_body,
        }

    @classmethod
    def get_by_id_or_404(cls, res, res_id, res_name, res_id_name):
        obj = app_db.session.query(res).get(res_id)
        if obj is None:
            raise APIError(
                status_code=404,
                title='The requested resource could not be found',
                messages={
                    res_name: [
                        'Not found',
                    ]
                },
                data={
                    res_id_name: res_id,
                },
            )
        return obj

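In the code above, we iterate over the schemas and load and validate them one by one. The result is a dictionary with the keys 'path', 'query' and 'body', each holding the merged load() results for that request type. This assumes that no two schemas of the same request type (path, query, body) define a parameter with the same name; if they do, the later result silently overwrites the earlier one. Also note that schema_check() calls load() with partial=True: Marshmallow then ignores fields that are absent from the input, so required=True is not enforced for them.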
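The same-name assumption matters because dict.update() lets a later value win without complaint. A minimal sketch with made-up load() results, showing the silent overwrite and a hypothetical merge_strict() helper that turns a name clash into an error instead:

```python
def merge_strict(results):
    """Merge load() result dicts, refusing duplicate parameter names."""
    merged = {}
    for result in results:
        clashes = merged.keys() & result.keys()
        if clashes:
            raise ValueError(f'duplicate parameter names: {sorted(clashes)}')
        merged.update(result)
    return merged


# Two query schemas that both (by mistake) define 'page'
results = [{'page': 1, 'per_page': 10}, {'page': 5, 'search': 'an'}]

# What request_schemas_load() does: the second 'page' silently wins.
merged = {}
for result in results:
    merged.update(result)
print(merged)  # {'page': 5, 'per_page': 10, 'search': 'an'}

try:
    merge_strict(results)
except ValueError as exc:
    print(exc)  # duplicate parameter names: ['page']
```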

The title argument becomes the RFC 7807 title, a short summary of the problem. Here it is one message per request parameter type, whichever field failed.

If validation fails, schema_check() raises an APIError with status code 400. It also passes the details of the ValidationError on to the response:

  • err.messages: the error messages per field
  • err.data: the raw input data
  • err.valid_data: the fields that did validate
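The post does not show the APIError class itself. Here is a minimal sketch of how such an exception could be turned into JSON error responses with a Flask error handler; this APIError, its to_dict() method and the /demo route are assumptions, not the author's code. Empty details are left out, so valid_data only appears when some fields did validate. RFC 7807 also defines the media type application/problem+json; the sketch keeps application/json, as in the responses shown later.

```python
from flask import Flask, jsonify


class APIError(Exception):
    """Hypothetical exception carrying an RFC 7807-style problem description."""

    def __init__(self, status_code, title, messages=None, data=None, valid_data=None):
        super().__init__(title)
        self.status_code = status_code
        self.title = title
        self.details = {'messages': messages, 'data': data, 'valid_data': valid_data}

    def to_dict(self):
        # 'status' and 'title' are RFC 7807 members; the rest are extensions.
        body = {'status': self.status_code, 'title': self.title}
        # Leave out empty details, e.g. valid_data when no field validated.
        body.update({key: value for key, value in self.details.items() if value})
        return body


app = Flask(__name__)


@app.errorhandler(APIError)
def handle_api_error(err):
    return jsonify(err.to_dict()), err.status_code


@app.route('/demo')
def demo():
    raise APIError(
        status_code=400,
        title='One or more request body parameters did not validate',
        messages={'name': ['Length must be between 2 and 40.']},
        data={'name': 'H'},
        valid_data={},
    )


response = app.test_client().get('/demo')
print(response.status_code)  # 400
print(response.get_json())
```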

I also included the method get_by_id_or_404() to show the response when a resource is not found.

Code for the City update method

With the APIUtils helper method request_schemas_load() we can now write the code for the City update method:

@blueprint_cities.route('/<int:city_id>', methods=['PUT'])
def cities_update(city_id):
    result = APIUtils.request_schemas_load([
        CitiesRequestPathParamsSchema(),
        CitiesUpdateRequestBodyParamsSchema()])

    city = APIUtils.get_by_id_or_404(City, city_id, 'City', 'city_id')

    for k, v in result['body'].items():
        setattr(city, k, v)
    app_db.session.commit()

    return jsonify({
        'data': CitiesResponseSchema().dump(city)
    }), 200

Here I use APIUtils.get_by_id_or_404() to get a single City by id. If the city does not exist, an APIError is raised. If there are no validation errors and the city exists, I copy the validated body parameters onto the City object and commit the change.
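The setattr() loop is safe here only because the body schema rejects unknown fields, so result['body'] can contain nothing but declared, validated columns. A self-contained sketch of the lookup-then-update flow, assuming SQLAlchemy 1.4 or later and an in-memory SQLite database; the local session replaces the author's app_db, and Session.get() is used as the newer equivalent of query(...).get().

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class City(Base):
    __tablename__ = 'city'

    id = Column(Integer, primary_key=True)
    name = Column(String(100), server_default='')


engine = create_engine('sqlite://')  # in-memory database for the demo
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([City(id=1, name='Beijing'), City(id=2, name='Berlin')])
    session.commit()

    # What result['body'] holds after a successful request_schemas_load()
    validated_body = {'name': 'Hamburg'}

    city = session.get(City, 2)  # returns None for an unknown id
    if city is not None:
        for key, value in validated_body.items():
            setattr(city, key, value)
        session.commit()

    unknown_city = session.get(City, 20)
    updated_name = session.get(City, 2).name

print(updated_name)  # Hamburg
print(unknown_city)  # None
```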

Some API error responses

I use curl to show some results. The database already contains some cities.

First we get two cities:

curl -i "http://127.0.0.1:5000/api/v1/cities?page=1&per_page=2"

The response:

HTTP/1.0 200 OK
Content-Type: application/json
Content-Length: 247
Server: Werkzeug/1.0.1 Python/3.8.5
Date: Mon, 29 Mar 2021 15:51:39 GMT

{
  "data": [
    {
      "id": 1, 
      "name": "Beijing"
    }, 
    {
      "id": 2, 
      "name": "Berlin"
    }
  ], 
  "meta": {
    "count": 2, 
    "limit": 2, 
    "offset": 0, 
    "page": 1, 
    "per_page": 2, 
    "total": 11
  }
}

Now let us rename Berlin to Hamburg, but with some deliberate mistakes.

Example 1: bad path parameter, city_id = 0

curl -i -X PUT -H "Content-Type: application/json" -d '{"name":"Hamburg"}' http://127.0.0.1:5000/api/v1/cities/0

The response:

HTTP/1.0 400 BAD REQUEST
Content-Type: application/json
Content-Length: 247
Server: Werkzeug/1.0.1 Python/3.8.5
Date: Mon, 29 Mar 2021 16:13:22 GMT

{
  "status": 400, 
  "title": "One or more request url path parameters did not validate", 
  "messages": {
    "city_id": [
      "Must be greater than or equal to 1 and less than or equal to 9999."
    ]
  }, 
  "data": {
    "city_id": 0
  }
}

Example 2: bad body parameter, name = H

curl -i -X PUT -H "Content-Type: application/json" -d '{"name":"H"}' http://127.0.0.1:5000/api/v1/cities/2

The response:

HTTP/1.0 400 BAD REQUEST
Content-Type: application/json
Content-Length: 205
Server: Werkzeug/1.0.1 Python/3.8.5
Date: Mon, 29 Mar 2021 16:15:46 GMT

{
  "status": 400, 
  "title": "One or more request body parameters did not validate", 
  "messages": {
    "name": [
      "Length must be between 2 and 40."
    ]
  }, 
  "data": {
    "name": "H"
  }
}

Example 3: add unknown parameter to body: something = nothing

curl -i -X PUT -H "Content-Type: application/json" -d '{"name":"Hamburg", "something": "nothing"}' http://127.0.0.1:5000/api/v1/cities/2

The response:

HTTP/1.0 400 BAD REQUEST
Content-Type: application/json
Content-Length: 273
Server: Werkzeug/1.0.1 Python/3.8.5
Date: Mon, 29 Mar 2021 16:21:40 GMT

{
  "status": 400,
  "title": "One or more request body parameters did not validate", 
  "messages": {
    "something": [
      "Unknown field."
    ]
  }, 
  "data": {
    "name": "Hamburg", 
    "something": "nothing"
  },
  "valid_data": {
    "name": "Hamburg"
  }
}

Example 4: unknown city, to show the 404 response

curl -i -X PUT -H "Content-Type: application/json" -d '{"name":"Hamburg"}' http://127.0.0.1:5000/api/v1/cities/20

The response:

HTTP/1.0 404 NOT FOUND
Content-Type: application/json
Content-Length: 173
Server: Werkzeug/1.0.1 Python/3.8.5
Date: Tue, 30 Mar 2021 16:26:38 GMT

{
  "status": 404, 
  "title": "The requested resource could not be found", 
  "messages": {
    "City": [
      "Not found"
    ]
  }, 
  "data": {
    "city_id": 20
  }
}

Using schema pre-processing and post-processing

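You may run into problems with default values for optional fields. One reason, with the code above, is that schema_check() calls load() with partial=True: Marshmallow then skips every field that is absent from the input, so a default set with missing= (load_default= in newer Marshmallow versions) is never applied. Fortunately, Marshmallow is flexible enough to let you do your own processing: you can extend a schema with methods decorated with @pre_load and @post_load, which run before and after deserialization.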

For example, I extended PaginationQueryParamsSchema with a @pre_load method to handle the pagination defaults properly:

from marshmallow import pre_load

class PaginationQueryParamsSchema(RequestQueryParamsSchema):
    @pre_load
    def pre_load_(self, data, many, **kwargs):
        # your processing here
        ...
        return data

    page = fields.Int(...)      # as defined above

    per_page = fields.Int(...)  # as defined above

Summary

Marshmallow not only makes it easy to dump objects with schemas but also includes an extensive set of validators that we can use to check request parameters. In this post I showed one way to process the request parameters of a RESTful API; of course, there are other ways. The Flask request object gives us request.view_args for the URL path parameters, request.args for the URL query parameters and request.get_json() for a JSON body (request.form for form-encoded data), and each of these can be passed to schema.load().

Links / credits

marshmallow: simplified object serialization
https://marshmallow.readthedocs.io/en/stable/

Problem Details for HTTP APIs
https://tools.ietf.org/html/rfc7807
