from pylab import figure, plot, setp
from numpy import array
from itertools import count
from math import sqrt, floor

def gcd(a, b):
    """Return the greatest common divisor of a and b.

    Uses the modulo form of Euclid's algorithm, so it terminates for any
    pair of numbers, including zero and negative values.  By convention
    gcd(a, 0) == abs(a) and gcd(0, 0) == 0.  Whole-valued floats are
    accepted; the result is then a float that compares equal to the
    integer gcd.
    """
    # The sign does not affect the set of common divisors.
    a, b = abs(a), abs(b)
    while b:
        a, b = b, a % b
    return a

def sum_squares(n):
    """Find every Pythagorean triple whose legs lie in range(1, n).

    Every pair of integers (x, y) with 1 <= y <= x < n is tested, and the
    pair is kept when z = sqrt(x**2 + y**2) is a whole number.

    Returns a tuple (composites, primitives) of numpy arrays holding one
    (x, y, z) row per triple.  A triple is primitive when its elements
    are pairwise coprime and composite otherwise.  z is the float
    produced by math.sqrt.  A list with no triples yields an array of
    shape (0, 3), so column slicing such as result[:, 0] still works.

    Raises AssertionError if a triple is found in which no element is
    divisible by four.
    """
    composites = []
    primitives = []
    for y in range(1, n):
        # Start x at y: only pairs on or below the line y = x are wanted.
        for x in range(y, n):
            z = sqrt(x**2 + y**2)
            if z != floor(z):
                # x**2 + y**2 is not a perfect square.
                continue
            if x % 4 != 0 and y % 4 != 0 and z % 4 != 0:
                # if we ever find a triple in which none of the
                # elements are divisible by 4 throw an error.
                print("%5i^2 + %5i^2 = %5i^2" % (x, y, z))
                raise AssertionError(
                    "Found a pythagorean triple with no element "
                    "divisible by four!")
            if gcd(x, y) == 1 and gcd(y, z) == 1 and gcd(x, z) == 1:
                # this triple is primitive
                primitives.append((x, y, z))
            else:
                composites.append((x, y, z))
    # reshape(-1, 3) is a no-op for non-empty results and turns the 1-D
    # array built from an empty list into a (0, 3) array.
    return (array(composites).reshape(-1, 3),
            array(primitives).reshape(-1, 3))

def gen_triples(n):
    """Generate Pythagorean triples using the equation
    (2pq, p**2 - q**2, p**2 + q**2) for integers p > q >= 1.

    A triple is kept when both of its legs, 2pq and p**2 - q**2, are
    less than n.  The legs are ordered so that the larger one comes
    first.  p and q are not required to be coprime, so some non-primitive
    triples are produced as well.

    Returns a numpy array with one (x, y, z) row per triple.  When no
    triple qualifies the array has shape (0, 3), so column slicing such
    as result[:, 0] still works.
    """
    pairs = []
    for p in count(2):
        # For a given p the leg p**2 - q**2 is smallest at q = p - 1,
        # where it equals 2p - 1.  Once that reaches n, no triple for
        # this or any larger p can have both legs below n.
        if 2 * p - 1 >= n:
            break
        for q in range(1, p):
            x = 2 * p * q
            y = p**2 - q**2
            z = p**2 + q**2
            if x < y:
                # swap x for y if x < y.  This puts all the (x,y)'s
                # below the line y = x.
                x, y = y, x
            # x is now the larger leg, so checking it covers both.
            if x < n:
                pairs.append((x, y, z))
    # reshape(-1, 3) is a no-op for non-empty results and turns the 1-D
    # array built from an empty list into a (0, 3) array.
    return array(pairs).reshape(-1, 3)

# Exclusive upper bound passed to both triple searches, and the base
# marker size used for the plotted points.
MAX = 200
SIZE = 4.0


def _draw_points(points, size, colour):
    """Plot column 0 against column 1 of points as borderless circles
    filled with colour, and return the plot handles."""
    handles = plot(points[:, 0], points[:, 1], 'o')
    setp(handles, markersize=size, markeredgewidth=0, markerfacecolor=colour)
    return handles


fig = figure()

# exhaustive search: composites in blue, primitives in green
composites, primitives = sum_squares(MAX)
pc = _draw_points(composites, SIZE, "#2222FF")
pp = _draw_points(primitives, SIZE, "#22FF22")

# formula-generated triples are drawn last, as slightly smaller white dots
triples = gen_triples(MAX)
pt = _draw_points(triples, SIZE - SIZE/5.0, "#FFFFFF")


fig.savefig("test.png", dpi=184)
