Custom model#
mlflow can be used to log custom models. This page explores how that works.
Set up#
The following cell starts the mlflow instance in docker.
%%bash
docker run -p 5000:5000 -dt --name my_server --rm \
ghcr.io/mlflow/mlflow \
bash -c "mlflow server --host 0.0.0.0 --port 5000"
f36c038cbd239d43e24220ee09150647ab5c03c7701c4040487e15e716f7e24a
Note Don’t forget to close the container when you’ve finished playing with this notebook:
!docker stop my_server
my_server
Now imports:
import mlflow
import pandas as pd
mlflow.set_tracking_uri(uri="http://localhost:5000")
Function#
The simplest approach is to define a function that takes the model input and returns the model output, and pass it as the python_model
argument of the mlflow.pyfunc.log_model
function:
def predict(model_input):
return model_input.apply(lambda x: x * 2)
with mlflow.start_run():
mlflow.pyfunc.log_model(
"model",
python_model=predict,
pip_requirements=["pandas"])
run_id = mlflow.active_run().info.run_id
Now you can load this model from mlflow. Note that you don't get the function back directly; instead you get a wrapper of type mlflow.pyfunc.PyFuncModel
. The following cell shows it:
# Load the model from the tracking server and perform inference
model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")
print("Loaded object type -", type(model))
Loaded object type - <class 'mlflow.pyfunc.PyFuncModel'>
But it does have a predict
method that you can use to get a prediction:
model.predict(pd.Series([10,20,30]))
0 20
1 40
2 60
dtype: int64
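At inference time the function-based flavor simply delegates predict to the logged callable. The following is a minimal sketch of that behavior using a hypothetical FunctionWrapper class standing in for mlflow.pyfunc.PyFuncModel, so it runs without a tracking server:

```python
import pandas as pd

# Hypothetical stand-in for mlflow.pyfunc.PyFuncModel: at inference
# time the function-based flavor just forwards the input to the
# callable that was logged.
class FunctionWrapper:
    def __init__(self, fn):
        self._fn = fn

    def predict(self, model_input):
        return self._fn(model_input)

def predict(model_input):
    return model_input.apply(lambda x: x * 2)

wrapped = FunctionWrapper(predict)
print(wrapped.predict(pd.Series([10, 20, 30])).tolist())  # [20, 40, 60]
```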
Class#
You can also log instances of subclasses of mlflow.pyfunc.PythonModel
that implement the predict
method. This method must take context
and model_input
parameters.
The following cell shows what this might look like:
class MyModel(mlflow.pyfunc.PythonModel):
test = "I'm field you want to acess!!!"
def predict(self, context, model_input, params=None):
return [x * 2 for x in model_input]
# Save the function as a model
with mlflow.start_run():
mlflow.pyfunc.log_model(
"model", python_model=MyModel(), pip_requirements=["pandas"]
)
run_id = mlflow.active_run().info.run_id
But after loading, you'll get a mlflow.pyfunc.PyFuncModel
, not an instance of the class we logged with mlflow.
# Load the model from the tracking server and perform inference
model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")
print("Loaded object type -", type(model))
Loaded object type - <class 'mlflow.pyfunc.PyFuncModel'>
To get a prediction from this object, simply use the predict
method.
print(f"Prediction:{model.predict(pd.Series([1, 2, 3]))}")
Prediction:[2, 4, 6]
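The predict signature also accepts an optional params argument, which can carry inference-time parameters. Below is a hedged sketch of how a model could use it; the plain ScalingModel class and its "factor" parameter are hypothetical, and a bare class stands in for mlflow.pyfunc.PythonModel so the example runs standalone:

```python
# Hypothetical model whose predict() honors the optional params
# argument; "factor" is an invented parameter for illustration.
class ScalingModel:
    def predict(self, context, model_input, params=None):
        factor = (params or {}).get("factor", 2)
        return [x * factor for x in model_input]

model = ScalingModel()
print(model.predict(None, [1, 2, 3]))                         # [2, 4, 6]
print(model.predict(None, [1, 2, 3], params={"factor": 10}))  # [10, 20, 30]
```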
You can extract the original object from the wrapper using the unwrap_python_model
method. The following cell accesses an attribute of the original object that was logged with mlflow.
model.unwrap_python_model().test
"I'm field you want to acess!!!"