Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Grouping in warp #498

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
110 changes: 100 additions & 10 deletions lightfm/_lightfm_fast.pyx.template
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ from libc.stdlib cimport free, malloc

{openmp_import}


ctypedef float flt

# Allow sequential code blocks in a parallel setting.
Expand Down Expand Up @@ -81,7 +80,8 @@ cdef int rand_r(unsigned int * seed) nogil:
return temper(seed[0]) / 2


cdef int sample_range(int min_val, int max_val, unsigned int *seed) nogil:
cdef int sample_range(int min_val, int max_val,
unsigned int *seed) nogil:

cdef int val_range

Expand All @@ -90,6 +90,17 @@ cdef int sample_range(int min_val, int max_val, unsigned int *seed) nogil:
return min_val + (rand_r(seed) % val_range)


cdef int *shuffle_rows(int *sh_rows, int rows, unsigned int *seed) nogil:
    # Fill ``sh_rows`` (caller-owned, at least ``rows`` ints) with the values
    # 0..rows-1 and then permute them in place using ``seed`` as the RNG state.
    #
    # Returns ``sh_rows`` so the declared ``int *`` return type is honoured;
    # previously the function fell off the end without returning anything.
    cdef int pos, pick, held

    # Start from the identity ordering.
    for pos in range(rows):
        sh_rows[pos] = pos

    # Swap each position with a randomly chosen position.
    # NOTE(review): this is a Fisher-Yates shuffle provided
    # ``sample_range(pos, rows, seed)`` returns a value in [pos, rows) --
    # confirm against sample_range's full body.
    for pos in range(rows - 1):
        pick = sample_range(pos, rows, seed)
        held = sh_rows[pick]
        sh_rows[pick] = sh_rows[pos]
        sh_rows[pos] = held

    return sh_rows


cdef int int_min(int x, int y) nogil:

if x < y:
Expand All @@ -110,6 +121,10 @@ cdef struct Pair:
int idx
flt val

# Location of one item group inside a CSR matrix's ``indices`` array.
cdef struct GroupData:
# Number of entries in the group's row (row end minus row start).
int size
# Position in ``indices`` at which the group's row begins.
int offset


cdef int reverse_pair_compare(const_void *a, const_void *b) nogil:

Expand Down Expand Up @@ -284,6 +299,39 @@ cdef inline int in_positives(int item_id, int user_id, CSRMatrix interactions) n
return 1


cdef inline int* gen_item_group_map(int *item_group_map,
                                    CSRMatrix item_groups,
                                    unsigned int *seed) nogil:
    # Build an item -> group lookup from ``item_groups`` (one CSR row per
    # group, column indices are item ids).
    #
    # ``item_group_map`` is caller-owned and must hold at least
    # ``item_groups.cols`` ints. On return, ``item_group_map[item]`` is the
    # row (group) the item was assigned to, or -1 when the item appears in
    # no group. Rows are visited in a shuffled order, so an item listed in
    # several groups ends up mapped to whichever of them is visited last.
    #
    # Returns ``item_group_map``.
    cdef int r, k, group, row_start, row_end
    cdef int *row_order

    # Callers hand over a freshly malloc'ed buffer; give every entry a
    # defined value so ungrouped items are recognisable instead of garbage.
    for k in range(item_groups.cols):
        item_group_map[k] = -1

    row_order = <int *>malloc(sizeof(int) * item_groups.rows)
    if row_order != NULL:
        shuffle_rows(row_order, item_groups.rows, seed)

    # Deliberately a plain sequential loop: the shuffled order only decides
    # the winner for multi-group items if rows are applied one after the
    # other, and a parallel loop would race on item_group_map[item].
    for r in range(item_groups.rows):
        if row_order != NULL:
            group = row_order[r]
        else:
            # Scratch allocation failed (or rows == 0): fall back to the
            # natural row order rather than dereferencing NULL.
            group = r

        row_start = item_groups.get_row_start(group)
        row_end = item_groups.get_row_end(group)
        for k in range(row_start, row_end):
            item_group_map[item_groups.indices[k]] = group

    # free(NULL) is a no-op, so no guard is needed here.
    free(row_order)

    return item_group_map


cdef inline GroupData *get_group_size(GroupData *g, int group, CSRMatrix item_groups) nogil:
    # Record where row ``group`` of ``item_groups`` lives in the CSR
    # ``indices`` array: ``g.offset`` is the row start and ``g.size`` the
    # number of entries in the row.
    #
    # Returns ``g`` so the declared ``GroupData *`` return type is honoured;
    # previously the function fell off the end without returning anything.
    cdef int row_start = item_groups.get_row_start(group)
    cdef int row_end = item_groups.get_row_end(group)

    g.offset = row_start
    g.size = row_end - row_start

    return g


cdef inline void compute_representation(CSRMatrix features,
flt[:, ::1] feature_embeddings,
flt[::1] feature_biases,
Expand Down Expand Up @@ -608,7 +656,9 @@ cdef void warp_update(double loss,
avg_learning_rate += update_features(item_features, lightfm.item_features,
lightfm.item_feature_gradients,
lightfm.item_feature_momentum,
i, positive_item_start_index, positive_item_stop_index,
i,
positive_item_start_index,
positive_item_stop_index,
-loss * user_component,
lightfm.adadelta,
lightfm.learning_rate,
Expand All @@ -618,7 +668,9 @@ cdef void warp_update(double loss,
avg_learning_rate += update_features(item_features, lightfm.item_features,
lightfm.item_feature_gradients,
lightfm.item_feature_momentum,
i, negative_item_start_index, negative_item_stop_index,
i,
negative_item_start_index,
negative_item_stop_index,
loss * user_component,
lightfm.adadelta,
lightfm.learning_rate,
Expand All @@ -629,17 +681,17 @@ cdef void warp_update(double loss,
lightfm.user_feature_gradients,
lightfm.user_feature_momentum,
i, user_start_index, user_stop_index,
loss * (negative_item_component -
positive_item_component),
loss * (negative_item_component
- positive_item_component),
lightfm.adadelta,
lightfm.learning_rate,
user_alpha,
lightfm.rho,
lightfm.eps)

avg_learning_rate /= ((lightfm.no_components + 1) * (user_stop_index - user_start_index)
+ (lightfm.no_components + 1) *
(positive_item_stop_index - positive_item_start_index)
+ (lightfm.no_components + 1)
* (positive_item_stop_index - positive_item_start_index)
+ (lightfm.no_components + 1)
* (negative_item_stop_index - negative_item_start_index))

Expand Down Expand Up @@ -784,6 +836,7 @@ def fit_logistic(CSRMatrix item_features,
def fit_warp(CSRMatrix item_features,
CSRMatrix user_features,
CSRMatrix interactions,
CSRMatrix item_groups,
int[::1] user_ids,
int[::1] item_ids,
flt[::1] Y,
Expand All @@ -801,12 +854,16 @@ def fit_warp(CSRMatrix item_features,

cdef int i, no_examples, user_id, positive_item_id, gamma
cdef int negative_item_id, sampled, row
cdef int group, rand_n
cdef double positive_prediction, negative_prediction
cdef double loss, MAX_LOSS
cdef GroupData *group_data
cdef bint use_groups
cdef flt weight
cdef flt *user_repr
cdef flt *pos_it_repr
cdef flt *neg_it_repr
cdef int *item_group_map
cdef unsigned int[::1] random_states

random_states = random_state.randint(0,
Expand All @@ -816,8 +873,15 @@ def fit_warp(CSRMatrix item_features,
no_examples = Y.shape[0]
MAX_LOSS = 10.0

use_groups = item_groups.nnz > 0

{nogil_block}

if use_groups:
group_data = <GroupData *>malloc(sizeof(GroupData) * no_examples)
item_group_map = <int *>malloc(sizeof(int) * item_groups.cols)
gen_item_group_map(item_group_map, item_groups, &random_states[0])

user_repr = <flt *>malloc(sizeof(flt) * (lightfm.no_components + 1))
pos_it_repr = <flt *>malloc(sizeof(flt) * (lightfm.no_components + 1))
neg_it_repr = <flt *>malloc(sizeof(flt) * (lightfm.no_components + 1))
Expand All @@ -828,6 +892,9 @@ def fit_warp(CSRMatrix item_features,
user_id = user_ids[row]
positive_item_id = item_ids[row]

if use_groups:
group = item_group_map[positive_item_id]
get_group_size(&group_data[i], group, item_groups)
if not Y[row] > 0:
continue

Expand Down Expand Up @@ -857,8 +924,13 @@ def fit_warp(CSRMatrix item_features,
while sampled < lightfm.max_sampled:

sampled = sampled + 1
negative_item_id = (rand_r(&random_states[{thread_num}])
% item_features.rows)
if use_groups:
rand_n = (rand_r(&random_states[{thread_num}])
% group_data[i].size)
negative_item_id = item_groups.indices[group_data[i].offset + rand_n]
else:
negative_item_id = (rand_r(&random_states[{thread_num}])
% item_features.rows)

compute_representation(item_features,
lightfm.item_features,
Expand Down Expand Up @@ -906,6 +978,9 @@ def fit_warp(CSRMatrix item_features,
free(user_repr)
free(pos_it_repr)
free(neg_it_repr)
if use_groups:
free(item_group_map)
free(group_data)

regularize(lightfm,
item_alpha,
Expand Down Expand Up @@ -1383,3 +1458,18 @@ def __test_in_positives(int row, int col, CSRMatrix mat):
return True
else:
return False


def test_item_group_map(int n, CSRMatrix mat, random_state):
    """
    Test helper: build the item -> group map for ``mat`` and return the
    group assigned to item ``n``.

    Parameters
    ----------
    n: int
        Item (column) index to look up; must satisfy ``0 <= n < mat.cols``.
    mat: CSRMatrix
        Group membership matrix, one row per group.
    random_state: object
        Source of randomness exposing ``randint`` (used to seed the shuffle).

    Returns
    -------
    int
        The row of ``mat`` that item ``n`` was mapped to, or -1 if the item
        does not appear in any row.

    Raises
    ------
    IndexError
        If ``n`` is outside ``[0, mat.cols)``.
    MemoryError
        If the temporary map cannot be allocated.
    """
    cdef unsigned int[::1] random_states
    cdef int *item_group_map
    cdef int col

    # Reject out-of-range lookups up front; indexing the raw C buffer with
    # them would read outside the allocation.
    if n < 0 or n >= mat.cols:
        raise IndexError('item index %d out of range for %d columns'
                         % (n, mat.cols))

    random_states = random_state.randint(0,
                                         np.iinfo(np.int32).max,
                                         size=1).astype(np.uint32)

    item_group_map = <int *>malloc(sizeof(int) * mat.cols)
    if item_group_map == NULL:
        raise MemoryError()

    # malloc leaves the buffer uninitialised; mark every item as ungrouped
    # so items missing from ``mat`` yield -1 rather than arbitrary memory.
    for col in range(mat.cols):
        item_group_map[col] = -1

    gen_item_group_map(item_group_map, mat, &random_states[0])
    val = item_group_map[n]
    free(item_group_map)
    return val