import sys
import os

# Command-line arguments:
#   argv[1]  prediction file: lines of "h r t score" (whitespace-separated)
#   argv[2]  triple file: lines of "h r t" (whitespace-separated)
#   argv[3]  output file that receives the merged triples
#   argv[4]  threshold string; the selection step below interprets it either
#            as a minimum score or as a top-k count
if len(sys.argv) < 5:
	# Fail with a readable message instead of a bare IndexError.
	sys.exit('usage: {} PRED_FILE TRIP_FILE OUT_FILE THRESHOLD'.format(sys.argv[0]))
pred_file = sys.argv[1]
trip_file = sys.argv[2]
out_file = sys.argv[3]
threshold = sys.argv[4]

# Load scored predictions. Each non-blank line must hold at least four
# whitespace-separated fields: h, r, t and a float score. The list is sorted
# by score, highest first, so a top-k selection can simply take a prefix.
with open(pred_file, 'r') as fi:
	data = []
	for line in fi:
		fields = line.split()  # split() with no args also strips the line
		if not fields:
			# Tolerate blank lines (e.g. a trailing empty line at EOF)
			# instead of crashing with IndexError.
			continue
		data.append((fields[0], fields[1], fields[2], float(fields[3])))
	data.sort(key=lambda rec: rec[3], reverse=True)

# Load the existing triples. Each non-blank line must hold at least three
# whitespace-separated fields (h, r, t); only the first three are used.
# Storing them in a set collapses duplicate lines into one entry.
with open(trip_file, 'r') as fi:
	trip = set()
	for line in fi:
		fields = line.split()
		if not fields:
			# Skip blank lines rather than raising IndexError.
			continue
		trip.add((fields[0], fields[1], fields[2]))

# Merge predictions into the triple set. The mode is decided by the threshold
# string itself (not by any other argument):
#   - an integer such as "100" keeps the k highest-scoring predictions;
#   - any other number such as "0.5" or "1e-3" is a minimum score.
try:
	top_k = int(threshold)
except ValueError:
	top_k = None

if top_k is None:
	threshold = float(threshold)
	for tp in data:
		# Written as "skip if below" (not "keep if >=") so NaN scores are
		# kept exactly as before.
		if tp[3] < threshold:
			continue
		trip.add((tp[0], tp[1], tp[2]))
else:
	threshold = top_k
	# data is sorted by descending score, so the top-k is a prefix. Clamping
	# at 0 makes a negative k select nothing instead of slicing from the end.
	for tp in data[:max(top_k, 0)]:
		trip.add((tp[0], tp[1], tp[2]))

# Emit every merged triple as one tab-separated "h\tr\tt" line. trip is a
# set, so the order of lines in the output is arbitrary.
with open(out_file, 'w') as fo:
	fo.writelines('{}\t{}\t{}\n'.format(*triple) for triple in trip)
