Ahora veamos como podemos exponer un modelo previamente entrenado para clasificar imágenes desde un API de forma sencilla usando Flask
Continuando con el ejemplo anterior donde entrenamos un modelo en Pytorch para reconocer imágenes, ahora toca exponer un API que permita a los clientes enviar las fotos para predecir la probabilidad que pertenezcan determinada clase
Flask
Flask es un micro framework para manejar peticiones HTTP, nos servirá para construir el API de forma sencilla
Aplicación de Flask
Primero, vamos a crear una aplicación de flask y establecer una ruta de pruebas:
app = flask.Flask("dog-feeder-api")
@app.route('/', methods=['GET'])
def home():
return jsonify({"message": "probably try /dogs instead"})
if __name__ == '__main__':
app.run(debug=True)
Como nuestro código tendrá muchas suposiciones, lo cual significa que es propenso a romperse, así que es conveniente tener una función que nos ayude a lidiar con los errores que ocurran en el servidor
@app.errorhandler(Exception)
def handle_500(e):
"""
Handles all the errors from the server and
applys a JSON format for the response
"""
return jsonify({"error": True, "message": str(e)}), 500
API con flask restful
El API puede ser creado usando request handlers similares a los que definimos previamente. Sin embargo, flask cuenta con un módulo llamado flask restful, el cual nos permite definir recursos, similares a routers/controllers que nos ayudará con la separación de contextos
TL;DR
Lo único que necesitamos asimilar de flask restful es que, podemos extender un componente Resource
y definir los verbos HTTP que vamos a soportar en ese recurso. Para este caso, utilizaremos sólo POST
y GET
para realizar las predicciones y obtener las clases respectivamente
Entonces, nuestro código base para el manejador del recurso iniciará de la siguiente forma:
class DogAPI(Resource):
def get(self):
return classes
def post(self):
full_path = self.save_picture_to_file(request.files["picture"])
prediction = self.predict(full_path)
return {"predictions": prediction}
También, debemos indicar a nuestra aplicación de Flask que vamos a utilizar este recurso:
api = Api(app)
api.add_resource(DogAPI, '/dogs')
Pytorch
cargar el modelo
Para la carga de nuestro modelo, en realidad vamos a reutilizar gran parte del codigo del capitulo anterior. Las dos operaciones que tenemos que realizar son cargar el modelo y definir el transformador de imágenes que vamos a utilizar:
def load_model():
device = "cpu"
model = models.resnet50(pretrained=False).to(device)
model.fc = nn.Sequential(
nn.Linear(2048, 128),
nn.ReLU(inplace=True),
nn.Linear(128, 2)).to(device)
model.load_state_dict(torch.load('./src/model_meta/dog-trainer.h5'))
return (model_classes, model)
def image_transformer():
normalizer = transforms.Normalize(
[0.5, 0.5, 0.5],
[0.5, 0.5, 0.5]
)
transformer = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
normalizer
])
return transformer
clasificación
Para llevar a cabo esta operación, es necesario implementar el método predict
que referenciamos en DogAPI#post
:
def predict(self, img):
validation_batch = torch.stack(
[transformer(Image.open(img)).to(device)])
prediction_tensor = model(validation_batch)
# transform the predictions to a probabilistic value
prediction_probabilistic = F.softmax(
prediction_tensor, dim=1).cpu().data.numpy()
return {
"sitting": str(prediction_probabilistic[0, 0]),
"standing": str(prediction_probabilistic[0, 1])
}
Lo único que nos hace falta aquí es asegurarnos que img
es el path de una imagen persistida (y accesible) en un dispositivo, para esto, tenemos que implementar DogAPI#save_picture_to_file
:
def save_picture_to_file(self, picture):
extension = os.path.splitext(picture.filename)[1]
file_name = str(uuid()) + extension
full_path = os.path.join(PREDICTION_DIR, file_name)
picture.save(full_path)
return full_path
pruebas
Tiempo de probar nuestro API:
# inicializar el API
python3 ./src/main.py
# probar el API
curl -i \
-X POST \
-H "Content-Type: multipart/form-data" \
-F "picture=@sitting.jpg" \
http://localhost:5000/dogs

conclusión
Es sencillo leer un modelo de Pytorch persistido para la clasificación de imágenes recibidas desde un API REST, todo esto, con un código muy similar a la prueba de clasificación que hicimos cuando recién entrenamos el modelo
Repositorio: https://github.com/carlosJoseloMtz/pytorch-model-flask-api