Pass arguments to function from BranchPythonOperator in Airflow

I am running the code below to create a DAG. The DAG is created, but the choose_best_model task fails with this error: ERROR - _choose_best_model() missing 1 required positional argument: 'ti'. My Airflow version is 1.10.3. How can I resolve this error?

my_dag.py

from airflow import DAG
from airflow.operators.python_operator import PythonOperator, BranchPythonOperator
from airflow.operators.bash_operator import BashOperator

from random import randint
from datetime import datetime

def _choose_best_model(ti):
    accuracies = ti.xcom_pull(task_ids=[
        'training_model_A',
        'training_model_B',
        'training_model_C'
    ])
    best_accuracy = max(accuracies)
    if (best_accuracy > 8):
        return 'accurate'
    return 'inaccurate'


def _training_model():
    return randint(1, 10)

with DAG("my_dag", start_date=datetime(2021, 1, 1),
    schedule_interval="@daily", catchup=False) as dag:

        training_model_A = PythonOperator(
            task_id="training_model_A",
            python_callable=_training_model
        )

        training_model_B = PythonOperator(
            task_id="training_model_B",
            python_callable=_training_model
        )

        training_model_C = PythonOperator(
            task_id="training_model_C",
            python_callable=_training_model
        )

        choose_best_model = BranchPythonOperator(
            task_id="choose_best_model",
            python_callable=_choose_best_model
        )

        accurate = BashOperator(
            task_id="accurate",
            bash_command="echo 'accurate'"
        )

        inaccurate = BashOperator(
            task_id="inaccurate",
            bash_command="echo 'inaccurate'"
        )

        [training_model_A, training_model_B, training_model_C] >> choose_best_model >> [accurate, inaccurate]

Solution:

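You need to set provide_context=True on your operator (BranchPythonOperator extends PythonOperator, which defines that parameter). You also need to add **kwargs to your function’s signature: with provide_context=True, Airflow 1.10.x passes every context entry (ti, ds, execution_date, ...) as a keyword argument, so a function that accepts only ti would fail on the other keywords.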
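To see why both changes are needed, here is a minimal simulation that runs without Airflow. FakeTaskInstance, run_callable and attempt are hypothetical helpers: run_callable is a simplified model, based on my reading of the 1.10.x PythonOperator, of how the callable is invoked (context merged with op_kwargs and passed as keywords). It is not Airflow's actual code, and it ignores the branching/skip logic of BranchPythonOperator. Fixed accuracies replace randint so the result is deterministic.

```python
class FakeTaskInstance:
    """Hypothetical stand-in for Airflow's TaskInstance; only xcom_pull is modelled."""

    def __init__(self, xcoms):
        self._xcoms = xcoms  # task_id -> value that task returned (its XCom)

    def xcom_pull(self, task_ids):
        # A list of task_ids yields a list of values in the same order.
        return [self._xcoms[task_id] for task_id in task_ids]


def run_callable(python_callable, context, provide_context=False, op_kwargs=None):
    """Simplified (hypothetical) model of how a 1.10.x PythonOperator calls python_callable."""
    kwargs = dict(op_kwargs or {})
    if provide_context:
        # Every context entry (ti, ds, ...) becomes a keyword argument,
        # with op_kwargs merged on top.
        kwargs = {**context, **kwargs}
    return python_callable(**kwargs)


TASK_IDS = ["training_model_A", "training_model_B", "training_model_C"]


def _choose_best_model_strict(ti):  # the question's signature
    accuracies = ti.xcom_pull(task_ids=TASK_IDS)
    return "accurate" if max(accuracies) > 8 else "inaccurate"


def _choose_best_model(ti, **kwargs):  # the answer's signature
    accuracies = ti.xcom_pull(task_ids=TASK_IDS)
    return "accurate" if max(accuracies) > 8 else "inaccurate"


# Hypothetical context: fixed accuracies instead of randint, plus one extra entry.
context = {
    "ti": FakeTaskInstance({"training_model_A": 3, "training_model_B": 9, "training_model_C": 5}),
    "ds": "2021-01-01",
}


def attempt(fn, **options):
    """Run fn through the simulated operator and report a TypeError as text."""
    try:
        return run_callable(fn, context, **options)
    except TypeError as exc:
        return "TypeError: " + str(exc)


no_context = attempt(_choose_best_model_strict)                       # no ti passed at all
no_kwargs = attempt(_choose_best_model_strict, provide_context=True)  # 'ds' is unexpected
fixed = attempt(_choose_best_model, provide_context=True)             # both fixes applied
print(no_context)
print(no_kwargs)
print(fixed)
```

The first call reproduces the error from the question, the second shows the failure you get with provide_context=True but no **kwargs, and the third returns 'accurate' because the best simulated accuracy (9) is above 8.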

The full list of context variables that can be passed to your python_callable is in the Airflow documentation (v1.10.15).

Once you do this, you can also pass additional custom parameters to your function using the op_kwargs parameter.
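As an illustration, the sketch below passes a custom threshold through op_kwargs instead of hard-coding 8. The threshold parameter and FakeTaskInstance are hypothetical, and the call is simulated without Airflow by merging a fake context with op_kwargs into one set of keyword arguments; the commented operator shows how it would roughly be wired up in a 1.10.x DAG.

```python
class FakeTaskInstance:
    """Hypothetical stand-in for Airflow's TaskInstance; only xcom_pull is modelled."""

    def __init__(self, xcoms):
        self._xcoms = xcoms  # task_id -> value that task returned

    def xcom_pull(self, task_ids):
        return [self._xcoms[task_id] for task_id in task_ids]


def _choose_best_model(ti, threshold, **kwargs):
    # 'threshold' is a hypothetical custom argument supplied through op_kwargs.
    accuracies = ti.xcom_pull(task_ids=["training_model_A", "training_model_B", "training_model_C"])
    return "accurate" if max(accuracies) > threshold else "inaccurate"


# In the DAG file this would be wired up roughly like this (Airflow 1.10.x, not run here):
# choose_best_model = BranchPythonOperator(
#     task_id="choose_best_model",
#     python_callable=_choose_best_model,
#     provide_context=True,
#     op_kwargs={"threshold": 8},
# )

# Simplified model of the call: context entries and op_kwargs merged into one set of keywords.
context = {
    "ti": FakeTaskInstance({"training_model_A": 3, "training_model_B": 9, "training_model_C": 5}),
    "ds": "2021-01-01",
}
at_8 = _choose_best_model(**dict(context, threshold=8))  # 9 > 8
at_9 = _choose_best_model(**dict(context, threshold=9))  # 9 > 9 is False
print(at_8, at_9)  # accurate inaccurate
```

Because op_kwargs end up as ordinary keyword arguments next to the context entries, a named parameter such as threshold is bound directly, while **kwargs still absorbs everything else.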


[...]
def _choose_best_model(ti, **kwargs):   # <-- here: **kwargs absorbs the other context entries
    accuracies = ti.xcom_pull(task_ids=[
        'training_model_A',
        'training_model_B',
        'training_model_C'
    ])
    [...]

with DAG("my_dag", start_date=datetime(2021, 1, 1),
    schedule_interval="@daily", catchup=False) as dag:

        [...]

        choose_best_model = BranchPythonOperator(
            task_id="choose_best_model",
            python_callable=_choose_best_model,
            provide_context=True,   # <-- here: pass the task context (ti, ds, ...) as keyword arguments
        )

        [...]
