Pandas shift that takes groups into account

I have chronological data (monthly aggregation per customer).

import pandas as pd

df=pd.DataFrame({'cust_id': [1,1,1,1,1,1,2,2,2,2,2],
                 'period' : [200010,200011,200012,200101,200102,200103,200010,200011,200012,200101,200103],
                 'volume' : [1,2,3,4,5,6,7,8,9,10,12],
                 'num_transactions': [3,4,5,6,7,8,9,10,11,12,13],
                 'label': [1,1,1,0,1,1,0,0,0,0,0]})

The DataFrame is sorted by customer and then by month, both in ascending order.
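Because shift moves values by row position rather than by date, this ordering matters. A minimal sketch, using a small made-up frame (not the data above), that enforces the assumed order explicitly before any shifting:

```python
import pandas as pd

# Hypothetical unsorted input, only for illustration.
df = pd.DataFrame({'cust_id': [2, 1, 1],
                   'period':  [200010, 200011, 200010],
                   'label':   [0, 1, 1]})

# shift() works on row positions, so sort by customer first, then by period,
# and reset the index so positions and labels line up again.
df = df.sort_values(['cust_id', 'period'], ignore_index=True)
print(df)
```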

The "label" column is essentially a categorical variable.

I want to add a column "next_month_label" that stores the label of the same customer's next month.

I tried shift and then realised that it ignores customer boundaries: the rows for customer 1 are directly followed by those of customer 2, so the last row of customer 1 "borrows" the label from the first row of customer 2. Instead, "next_month_label" for the last row of customer 1 should stay empty (null).
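To make the problem concrete, here is a small sketch of the ungrouped shift on the cust_id and label columns from the question; the column name naive_next is made up for illustration:

```python
import pandas as pd

df = pd.DataFrame({'cust_id': [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
                   'label':   [1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0]})

# A plain shift(-1) ignores cust_id: it pulls every value up one row
# across the whole frame, including across the customer boundary.
df['naive_next'] = df['label'].shift(-1)

# Row 5 is customer 1's last month, yet it receives customer 2's first label.
print(df.loc[5, 'naive_next'])
```

Only the very last row of the frame (row 10) ends up as NaN; customer 1's last row silently gets a value.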

How can I do that?

The expected result should look like this:

import numpy as np
import pandas as pd

df=pd.DataFrame({'cust_id': [1,1,1,1,1,1,2,2,2,2,2],
                 'period' : [200010,200011,200012,200101,200102,200103,200010,200011,200012,200101,200103],
                 'volume' : [1,2,3,4,5,6,7,8,9,10,12],
                 'num_transactions': [3,4,5,6,7,8,9,10,11,12,13],
                 'label': [1,1,1,0,1,1,0,0,0,0,0],
                 'next_month_label': [1,1,0,1,1,np.nan,0,0,0,0,np.nan]})
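In this expected output, "next month" effectively means "the next row for the same customer": customer 2 has no 200102 row, so its 200101 row takes the label from 200103. If a skipped month should produce NaN instead, one possible variant compares the month distance between consecutive rows. This is a sketch that assumes period is always an integer of the form YYYYMM:

```python
import pandas as pd

df = pd.DataFrame({'cust_id': [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
                   'period':  [200010, 200011, 200012, 200101, 200102, 200103,
                               200010, 200011, 200012, 200101, 200103],
                   'label':   [1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0]})

# Assumes YYYYMM integers: turn them into a running month count,
# so that 200012 -> 200101 is a distance of exactly 1.
month_idx = (df['period'] // 100) * 12 + df['period'] % 100

# Next row's label and next row's month, both computed per customer.
next_label = df.groupby('cust_id')['label'].shift(-1)
next_month_idx = month_idx.groupby(df['cust_id']).shift(-1)

# Keep the next label only when that row is exactly one calendar month later;
# gaps and each customer's last row become NaN.
df['next_month_label'] = next_label.where(next_month_idx - month_idx == 1)
print(df)
```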

Solution:

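Group by cust_id before shifting, so that shift(-1) only moves labels within each customer's own rows and the last row of every customer gets NaN. Let me know if this gives you the required result: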

import pandas as pd

df=pd.DataFrame({'cust_id': [1,1,1,1,1,1,2,2,2,2,2],
                 'period' : [200010,200011,200012,200101,200102,200103,200010,200011,200012,200101,200103],
                 'volume' : [1,2,3,4,5,6,7,8,9,10,12],
                 'num_transactions': [3,4,5,6,7,8,9,10,11,12,13],
                 'label': [1,1,1,0,1,1,0,0,0,0,0]})

# shift(-1) is applied separately inside each cust_id group, so the last row
# of every customer gets NaN instead of the next customer's first label.
df['next_month_label'] = df.groupby('cust_id')['label'].shift(-1)

print(df)

    cust_id  period  volume  num_transactions  label  next_month_label
0         1  200010       1                 3      1               1.0
1         1  200011       2                 4      1               1.0
2         1  200012       3                 5      1               0.0
3         1  200101       4                 6      0               1.0
4         1  200102       5                 7      1               1.0
5         1  200103       6                 8      1               NaN
6         2  200010       7                 9      0               0.0
7         2  200011       8                10      0               0.0
8         2  200012       9                11      0               0.0
9         2  200101      10                12      0               0.0
10        2  200103      12                13      0               NaN
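For comparison, a sketch of an alternative that avoids groupby: shift the whole column, then blank out every row whose following row belongs to a different customer. This is a different technique from the answer above, and it assumes each customer's rows are contiguous (as they are once the frame is sorted by cust_id):

```python
import pandas as pd

df = pd.DataFrame({'cust_id': [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
                   'label':   [1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0]})

# True where the next row has the same customer; False on each customer's
# last row (including the very last row, where the shifted cust_id is NaN).
same_customer = df['cust_id'].eq(df['cust_id'].shift(-1))

# Shift across the whole frame, then keep values only within a customer.
df['next_month_label'] = df['label'].shift(-1).where(same_customer)
print(df)
```

The groupby version is more robust because it does not require contiguous customer blocks. In both cases the result column is float64, since NaN cannot be stored in an int64 column; .astype('Int64') converts it to pandas' nullable integer type if integer labels are preferred.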