Assume there are two DataFrames. The first holds transition probabilities, and the second stores all distinct permutation sequences of 3 items: A, B and C. The items have different numbers of replicates (A has 3, B has 2, C has 2) and different values (A = 10, B = 5, C = 1). The goal is to compute the total amount of each permutation sequence and append it to the second DataFrame as a new column.
Here, amount = value of the item in col0 + the sum, over every later column, of that column's item value * the transition probability from the previous column's item to it (the matrix is directional, e.g., A -> B is 0.4, but B -> A is 0.2).
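Stated as a formula (the symbols v, P and s are notation introduced here, not from the original post): v(x) is the value of item x, P(x -> y) is the entry in row x, column y of the transition DataFrame, and s_0 to s_6 are the items in columns 0 to 6 of one sequence.

$$
\mathrm{amount}(s) = v(s_0) + \sum_{i=1}^{6} v(s_i)\, P(s_{i-1} \to s_i)
$$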
1st DataFrame (TP)
import pandas as pd
import numpy as np
from itertools import permutations
probs = np.asarray([[0.3, 0.4, 0.3],
                    [0.2, 0.3, 0.5],
                    [0.6, 0.1, 0.3]])
prob_df = pd.DataFrame(probs, index = ['A', 'B', 'C'])
prob_df.columns = ['A', 'B', 'C']
prob_df
     A    B    C
A  0.3  0.4  0.3
B  0.2  0.3  0.5
C  0.6  0.1  0.3
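The examples read this matrix with rows as the source item and columns as the target (A -> B is 0.4, B -> A is 0.2), so a lookup is `prob_df.loc[source, target]`. A small sketch confirming that orientation, plus a sanity check that every row sums to 1 (assuming the matrix is meant to be row-stochastic, which holds for these numbers):

```python
import numpy as np
import pandas as pd

# Same transition matrix as in the question: rows = "from" item, columns = "to" item
probs = np.asarray([[0.3, 0.4, 0.3],
                    [0.2, 0.3, 0.5],
                    [0.6, 0.1, 0.3]])
prob_df = pd.DataFrame(probs, index=['A', 'B', 'C'], columns=['A', 'B', 'C'])

# P(from -> to) is prob_df.loc[from, to]
p_ab = prob_df.loc['A', 'B']  # A -> B
p_ba = prob_df.loc['B', 'A']  # B -> A
print(p_ab, p_ba)  # 0.4 0.2

# Each row should be a probability distribution over the next item
print(np.allclose(prob_df.sum(axis=1), 1.0))  # True
```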
2nd DataFrame
items = ['A','A','A','B','B','C','C']
perms = permutations(items)
output = [i for i in perms]
perm_df = pd.DataFrame(output).drop_duplicates()
print(perm_df.shape)
perm_df.head()
    0  1  2  3  4  5  6
0   A  A  A  B  B  C  C
2   A  A  A  B  C  B  C
3   A  A  A  B  C  C  B
12  A  A  A  C  B  B  C
13  A  A  A  C  B  C  B
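The `permutations` + `drop_duplicates` route builds all 7! = 5040 tuples and keeps 210 of them, which is the multinomial coefficient 7! / (3! * 2! * 2!). A minimal sketch checking that count, with a hypothetical helper `unique_permutations` (not a standard library function) that yields each distinct ordering once without generating duplicates:

```python
from collections import Counter
from itertools import permutations
from math import factorial

items = ['A', 'A', 'A', 'B', 'B', 'C', 'C']
counts = Counter(items)

# Multinomial coefficient: 7! / (3! * 2! * 2!) distinct orderings
expected = factorial(len(items))
for c in counts.values():
    expected //= factorial(c)

# Brute force, as in the question: all 5040 tuples, then dedupe
distinct = set(permutations(items))


def unique_permutations(counter, length):
    """Hypothetical helper: yield each distinct ordering of a multiset once,
    in lexicographic order, picking the next item from the remaining counts."""
    if length == 0:
        yield ()
        return
    for item in sorted(counter):
        if counter[item] > 0:
            counter[item] -= 1
            for rest in unique_permutations(counter, length - 1):
                yield (item,) + rest
            counter[item] += 1


unique = list(unique_permutations(Counter(items), len(items)))
print(expected, len(distinct), len(unique))  # 210 210 210
```

`pd.DataFrame(unique)` would hold the same 210 rows, but with a fresh 0..209 index in lexicographic order rather than the labels that `drop_duplicates` keeps (0, 2, 3, 12, ...).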
The amount of each permutation sequence is calculated as follows:
for row0: (A, A, A, B, B, C, C)
the amount = 10 (col0) + 10 * 0.3 (col0->col1) + 10 * 0.3 (col1->col2) + 5 * 0.4 (col2->col3) + 5 * 0.3 (col3->col4) + 1 * 0.5 (col4->col5) + 1 * 0.3 (col5->col6) = 20.3
for row12: (A, A, A, C, B, B, C)
the amount = 10 + 10 * 0.3 + 10 * 0.3 + 1 * 0.3 + 5 * 0.1 + 5 * 0.3 + 1 * 0.5 = 18.8
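A plain-Python reference implementation of that calculation is handy for checking any vectorized version against. This is a sketch; `values` and `amount` are names introduced here:

```python
import pandas as pd

# Item values and transition probabilities from the question
values = {'A': 10, 'B': 5, 'C': 1}
prob_df = pd.DataFrame([[0.3, 0.4, 0.3],
                        [0.2, 0.3, 0.5],
                        [0.6, 0.1, 0.3]],
                       index=['A', 'B', 'C'], columns=['A', 'B', 'C'])


def amount(seq):
    """Value of the first item, plus value * P(prev -> item) for every later item."""
    total = float(values[seq[0]])
    for prev, item in zip(seq, seq[1:]):
        total += values[item] * float(prob_df.loc[prev, item])
    return total


print(round(amount(('A', 'A', 'A', 'B', 'B', 'C', 'C')), 10))  # 20.3
print(round(amount(('A', 'A', 'A', 'C', 'B', 'B', 'C')), 10))  # 18.8
```

Applied row by row, e.g. `perm_df[list(range(7))].apply(lambda row: amount(tuple(row)), axis=1)`, it is slow for large tables but easy to trust; for row2 (A, A, A, B, C, B, C) it gives 10 + 3 + 3 + 2 + 0.5 + 0.5 + 0.5 = 19.5.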
Any ideas? Many thanks in advance!
Solution:
I would use shift+concat+stack to build a long DataFrame of source/target pairs, merge it with the stacked transition probabilities, reshape back with unstack, then multiply by each cell's item value and sum:
values = {'A': 10, 'B': 5, 'C': 1}
# value of the item in every cell, with the same labels as perm_df
weights = perm_df.apply(lambda col: col.map(values))
perm_df['out'] = (pd
    .concat({'source': perm_df.shift(axis=1),
             'target': perm_df}, axis=1)
    .stack()
    .merge(prob_df.stack().rename('prob'),
           left_on=['source', 'target'],
           right_index=True, how='left')
    ['prob'].fillna(1).unstack()  # col0 has no source, so its probability is 1
    .mul(weights).sum(axis=1)
)
Output:
       0  1  2  3  4  5  6   out
0      A  A  A  B  B  C  C  20.3
2      A  A  A  B  C  B  C  19.5
3      A  A  A  B  C  C  B  19.3
12     A  A  A  C  B  B  C  18.8
13     A  A  A  C  B  C  B  17.8
...   .. .. .. .. .. .. ..   ...
4216   C  C  A  B  B  A  A  15.8
4272   C  C  B  A  A  A  B  11.8
4273   C  C  B  A  A  B  A  10.8
4276   C  C  B  A  B  A  A  10.8
4290   C  C  B  B  A  A  A  11.3
[210 rows x 8 columns]
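A different technique that skips the merge/unstack round trip: encode A/B/C as the row/column positions 0/1/2 of the probability matrix and use NumPy integer-array (fancy) indexing to fetch every transition probability at once. This is a sketch, assuming the label order matches the matrix order; `labels`, `lookup` and `codes` are names introduced here:

```python
from itertools import permutations

import numpy as np
import pandas as pd

labels = ['A', 'B', 'C']
values = np.array([10, 5, 1])            # value of A, B, C (same order as labels)
probs = np.array([[0.3, 0.4, 0.3],
                  [0.2, 0.3, 0.5],
                  [0.6, 0.1, 0.3]])      # probs[i, j] = P(labels[i] -> labels[j])

items = ['A', 'A', 'A', 'B', 'B', 'C', 'C']
perm_df = pd.DataFrame(list(permutations(items))).drop_duplicates()

# Encode each letter as its matrix position: A -> 0, B -> 1, C -> 2
lookup = {label: i for i, label in enumerate(labels)}
codes = perm_df.apply(lambda col: col.map(lookup)).to_numpy()

# Probability of every (previous, current) pair, shape (n_rows, 6)
step_probs = probs[codes[:, :-1], codes[:, 1:]]

# First item counts at full value; each later item is weighted by its transition probability
perm_df['out'] = values[codes[:, 0]] + (values[codes[:, 1:]] * step_probs).sum(axis=1)
```

Working on plain integer arrays avoids building the long 210 x 7 source/target frame, which may matter for larger inputs. As a spot check, `perm_df.loc[2, 'out']` (A, A, A, B, C, B, C) comes out as 10 + 3 + 3 + 2 + 0.5 + 0.5 + 0.5 = 19.5.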