scikit-learnモデルをFlaskでAPI化する

こちらの記事のようにscikit-learnで作ったモデルを、pickleでファイルに保存する。

import pickle
with open('model.pickle', mode='wb') as fp:
    pickle.dump(model, fp)

pickleで保存したファイルは、下記のようにするとモデルとして復元され、予測処理を行うことができる。

import pickle
with open('model.pickle', mode='rb') as fp:
    model = pickle.load(fp)
model.predict([[30]])

このモデルをFlaskを使ってAPI化する。

まず、flaskをインストールする。

pip install flask

FlaskはPythonのWebフレームワークだが、マイクロフレームワークを標榜しているため、ファイルを1つだけ作れば充分に動作する。
下記のようなファイルをapp.pyという名前で作成する。

from flask import Flask, request, Response
import pickle
import json

app = Flask(__name__)


@app.route('/predict', methods=['POST'])
def predict():
    json_data = request.get_json()

    with open('model.pickle', mode='rb') as fp:
        model = pickle.load(fp)
    pred = model.predict([[json_data['value']]])

    result = {
        "result": pred[0][0]
    }

    return Response(json.dumps(result))

あとは、下記のコマンドでFlaskの開発用サーバを立ち上げる。

flask run --reload

reload引数を付けておけば、ソースコードを書き換えて保存すれば自動的に反映される。

開発用サーバが起動すると、

Running on http://127.0.0.1:5000/

と表示される。デフォルトで5000番ポートを使用する。
Ctrl+Cで開発用サーバが終了する。

開発用サーバの起動中に、例えば下記のようなCurlコマンドを実行し、APIコールしてみる。

curl --request POST --url http://localhost:5000/predict --header 'content-type: application/json' --data '{"value": 30}'

VS Codeを使っているなら、REST Clientという拡張機能をインストールして、拡張子がhttpのファイルを作成し、下記のような内容を入力して実行すると良い。

POST http://localhost:5000/predict
Content-Type: application/json

{
    "value": 30
}

下記のような実行結果が表示される。(VS Code+REST Clientの場合)

HTTP/1.0 200 OK
Content-Type: text/html; charset=utf-8
Content-Length: 30
Server: Werkzeug/1.0.1 Python/3.8.5
Date: Thu, 07 Oct 2021 07:08:01 GMT

{
  "result": 18.002818181818185
}