import sys, time, tempfile

from bz2 import BZ2File as open_bz2

class Timer( object ):
    def __init__( self ):
        self.time = None
    def reset( self ):
        self.time = time.clock()
    def display( self ):
        print "%0.3f seconds" % ( time.clock() - self.time )
# Command line: <query file> <target file>; both names are handed to
# read_bed() further down.
query, target = sys.argv[1], sys.argv[2]

def read_bed( fname ):
    """ Yield a ( start, end ) pair of ints for every line of a
        bzip2-compressed, tab-separated file.

        fname -- path of the compressed file to read.

        Columns 1 and 2 (counting from zero) of each line are converted
        with int(); every other column is ignored.  A line with fewer
        than three columns raises IndexError and a non-numeric
        coordinate raises ValueError. """
    f = open_bz2( fname )
    try:
        for line in f:
            # Byte-string separators: identical to the plain literals on
            # Python 2, and still correct where the decompressor hands
            # back bytes instead of str.
            fields = line.rstrip( b"\r\n" ).split( b"\t" )
            yield int( fields[1] ), int( fields[2] )
    finally:
        # Runs when the input is exhausted, when a line fails to parse,
        # and when a partially consumed generator is closed, so the
        # file handle is never left open.
        f.close()

# Echo the two input file names given on the command line.
print "Query intervals: ", query
print "Target intervals:", target

# Waste memory here to avoid variability in disk read speed: both files are
# read completely into lists of ( start, end ) pairs before any timing starts.
query_intervals = list( read_bed( query ) )
target_intervals = list( read_bed( target ) )

# One stopwatch shared by every benchmark section below; each section calls
# reset() before the work it measures and display() after it.
timer = Timer()







import sqlite3, bisect

def interval_size_code(low,high):
    """ Return the power-of-two size code of the interval [low,high).

        The code is the largest power of two strictly smaller than
        high-low, or 1 when high-low is 2 or less (including zero and
        negative lengths). """
    span = high - low
    code = 1
    doubled = 2
    # Invariant: doubled == 2*code.  Keep stepping up while the next
    # power of two is still strictly below the interval length.
    while doubled < span:
        code = doubled
        doubled *= 2
    return code

def interval_query(maximum_size,table,select_what):
    """ Build the SQL text that retrieves rows of `table` intersecting
        the interval [:lower,:upper).

        maximum_size -- largest size code to cover; one pair of
                        sub-selects is produced for every power of two
                        from 1 up to and including this value.
        table        -- table name, inserted into the SQL verbatim.
        select_what  -- column list, inserted into the SQL verbatim.

        Returns the sub-selects joined with ' union all ', or an empty
        string when maximum_size is below 1.  :lower and :upper are
        left as named parameters for the caller to bind. """
    # Rows whose low end falls inside the query window widened to the
    # left by the size code.
    by_low = 'select %s from %s where asize = %d and low < :upper and :lower-%d < low'
    # Rows whose high end falls inside the window widened to the right,
    # skipping anything the by_low sub-select already returns.
    by_high = ('select %s from %s where asize = %d and high < :upper+%d and :lower < high'
               ' and not (low < :upper and :lower-%d < low)')

    parts = [ ]
    code = 1
    while code <= maximum_size:
        parts.append(by_low % (select_what,table,code,code))
        parts.append(by_high % (select_what,table,code,code,code))
        code *= 2
    return ' union all '.join(parts)




# Benchmark 1: size-binned lookup using sorted Python lists and bisect.
print "---> Using pfh's method, pure python version"

print "Building:",
timer.reset()

# bins maps a size code to a pair of lists: ( start, index ) tuples and
# ( end, index ) tuples, where index is the interval's position in
# target_intervals.  Both lists are sorted once all targets are added.
bins = { }
for i, ( start, end ) in enumerate( target_intervals ):
     asize = interval_size_code(start,end)
     if asize not in bins: bins[asize] = ([],[])
     bins[asize][0].append( (start,i) )     
     bins[asize][1].append( (end,i) )     
for asize in bins:
     bins[asize][0].sort()
     bins[asize][1].sort()

# Prints the time spent building the bins.
timer.display()

print "Using:",
timer.reset()
total_retrieved = 0
for start, end in query_intervals:
    # A set, so a target picked up by both probes below is counted once.
    result = set()
    for asize in bins:        
        # Probe by start coordinate: targets in this bin whose start lies in
        # [start-asize+1, end).  A 1-tuple (x,) sorts before every (x, i), so
        # bisect_left lands on the first entry whose coordinate is >= x.
        left = bisect.bisect_left(bins[asize][0],(start-asize+1,))
        right = bisect.bisect_left(bins[asize][0],(end,))
        for item in bins[asize][0][left:right]:
            result.add(item[1])
            
        # Probe by end coordinate: targets in this bin whose end lies in
        # [start+1, end+asize).
        left = bisect.bisect_left(bins[asize][1],(start+1,))
        right = bisect.bisect_left(bins[asize][1],(end+asize,))
        for item in bins[asize][1][left:right]:
            result.add(item[1])
    total_retrieved += len(result)
# Prints the time spent answering every query.
timer.display()
#print 'Retrieved %d intervals in total' % total_retrieved



# Benchmark 2: the same size-binned lookup, expressed as SQL against an
# in-memory SQLite database.
print "---> Using pfh's method, SQLite version"

print "Building:",
timer.reset()

# One row per target interval: its coordinates, its size code and its
# position in target_intervals.  The two indexes are keyed on
# ( asize, low ) and ( asize, high ), the columns that the SQL produced by
# interval_query() filters on.
conn = sqlite3.connect(':memory:')
conn.executescript("""
    create table intervals (low,high,asize,i);
    create index index1 on intervals(asize,low);
    create index index2 on intervals(asize,high);
""")
# Largest size code seen; it decides how many sub-selects the query needs.
max_size = 0
for i, ( start, end ) in enumerate( target_intervals ):
     asize = interval_size_code(start,end)
     max_size = max(max_size,asize)
     conn.execute('insert into intervals values (?,?,?,?)',(start,end,asize,i))
conn.commit()

# Prints the time spent creating and filling the table.
timer.display()


print "Using:",
timer.reset()
# NOTE: this rebinds the module-level name `query`, which until now held
# the query file name taken from the command line, to the SQL text.
query = interval_query(max_size, 'intervals', 'i')
total_retrieved = 0
for start, end in query_intervals:
    # The SQL's :lower and :upper parameters are bound by name.
    result = conn.execute(query,dict(lower=start,upper=end)).fetchall()
    total_retrieved += len(result)
# Prints the time spent building the SQL and answering every query.
timer.display()
#print 'Retrieved %d intervals in total' % total_retrieved


# Disabled reference run (the condition is the constant 0): answers every
# query with a plain overlap test on the same table and prints the total
# row count.  Change the condition to a true value to run it.
if 0:
    # Extra index on ( low, high ) for the direct overlap query below.
    conn.execute('create index index3 on intervals(low,high)')
    conn.commit()
    print
    print 'Brute force retrieval:'
    timer.reset()
    # Two half-open intervals overlap when each one starts before the
    # other one ends.
    query = 'select i from intervals where low < :upper and :lower < high'
    total_retrieved = 0
    for start, end in query_intervals:
        result = conn.execute(query,dict(lower=start,upper=end)).fetchall()
        total_retrieved += len(result)
    timer.display()
    print 'Retrieved %d intervals in total' % total_retrieved




