Exponer modelo entrenado de Pytorch en un API con Flask

Ahora veamos como podemos exponer un modelo previamente entrenado para clasificar imágenes desde un API de forma sencilla usando Flask

Continuando con el ejemplo anterior donde entrenamos un modelo en Pytorch para reconocer imágenes, ahora toca exponer un API que permita a los clientes enviar las fotos para predecir la probabilidad que pertenezcan determinada clase

Flask

Flask es un micro framework para manejar peticiones HTTP, nos servirá para construir el API de forma sencilla

Aplicación de Flask

Primero, vamos a crear una aplicación de flask y establecer una ruta de pruebas:

app = flask.Flask("dog-feeder-api")


@app.route('/', methods=['GET'])
def home():
    return jsonify({"message": "probably try /dogs instead"})


if __name__ == '__main__':
    app.run(debug=True)
Aplicación base de Flask

Como nuestro código tendrá muchas suposiciones, lo cual significa que es propenso a romperse, así que es conveniente tener una función que nos ayude a lidiar con los errores que ocurran en el servidor

@app.errorhandler(Exception)
def handle_500(e):
    """
    Handles all the errors from the server and
    applys a JSON format for the response
    """
    return jsonify({"error": True, "message": str(e)}), 500
500 error handler

API con flask restful

El API puede ser creado usando request handlers similares a los que definimos previamente. Sin embargo, flask cuenta con un módulo llamado flask restful, el cual nos permite definir recursos, similares a routers/controllers que nos ayudará con la separación de contextos

TL;DR

Lo único que necesitamos asimilar de flask restful es que, podemos extender un componente Resource y definir los verbos HTTP que vamos a soportar en ese recurso. Para este caso, utilizaremos sólo POST y GET para realizar las predicciones y obtener las clases respectivamente

Entonces, nuestro código base para el manejador del recurso iniciará de la siguiente forma:

class DogAPI(Resource):

    def get(self):
        return classes

    def post(self):
        full_path = self.save_picture_to_file(request.files["picture"])
        prediction = self.predict(full_path)
        return {"predictions": prediction}
DogAPI Rest handler

También, debemos indicar a nuestra aplicación de Flask que vamos a utilizar este recurso:

api = Api(app)

api.add_resource(DogAPI, '/dogs')
Aplicación base de Flask

Pytorch

cargar el modelo

Para la carga de nuestro modelo, en realidad vamos a reutilizar gran parte del codigo del capitulo anterior. Las dos operaciones que tenemos que realizar son cargar el modelo y definir el transformador de imágenes que vamos a utilizar:

def load_model():
    device = "cpu"

    model = models.resnet50(pretrained=False).to(device)

    model.fc = nn.Sequential(
        nn.Linear(2048, 128),
        nn.ReLU(inplace=True),
        nn.Linear(128, 2)).to(device)

    model.load_state_dict(torch.load('./src/model_meta/dog-trainer.h5'))

    return (model_classes, model)


def image_transformer():
    normalizer = transforms.Normalize(
        [0.5, 0.5, 0.5],
        [0.5, 0.5, 0.5]
    )

    transformer = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalizer
    ])

    return transformer
Lectura del modelo existente

clasificación

Para llevar a cabo esta operación, es necesario implementar el método predict que referenciamos en DogAPI#post:

    def predict(self, img):
        validation_batch = torch.stack(
            [transformer(Image.open(img)).to(device)])

        prediction_tensor = model(validation_batch)

        # transform the predictions to a probabilistic value
        prediction_probabilistic = F.softmax(
            prediction_tensor, dim=1).cpu().data.numpy()

        return {
            "sitting": str(prediction_probabilistic[0, 0]),
            "standing": str(prediction_probabilistic[0, 1])
        }
DogAPI#predict

Lo único que nos hace falta aquí es asegurarnos que img es el path de una imagen persistida (y accesible) en un dispositivo, para esto, tenemos que implementar DogAPI#save_picture_to_file:

    def save_picture_to_file(self, picture):
        extension = os.path.splitext(picture.filename)[1]
        file_name = str(uuid()) + extension
        full_path = os.path.join(PREDICTION_DIR, file_name)
        picture.save(full_path)
        return full_path

pruebas

Tiempo de probar nuestro API:

# inicializar el API
python3 ./src/main.py

# probar el API
curl -i \
  -X POST \
  -H "Content-Type: multipart/form-data" \
  -F "picture=@sitting.jpg" \
  http://localhost:5000/dogs
Probar nuestro API
Resultados de la predicción

conclusión

Es sencillo leer un modelo de Pytorch persistido para la clasificación de imágenes recibidas desde un API REST, todo esto, con un código muy similar a la prueba de clasificación que hicimos cuando recién entrenamos el modelo

Repositorio: https://github.com/carlosJoseloMtz/pytorch-model-flask-api

Carlos Jose Martinez Arenas

Carlos Jose Martinez Arenas

Enamorado de la tecnología, con interés en machine learning, código limpio, buenas prácticas y arquitectura de sistemas; música y cine; naturaleza, espacios abiertos y los perritos