SIMD True Branching with "movemask"

CS 441 Lecture, Dr. Lawlor




Detailed example: movemask and SSE Branching

Let's start with the Mandelbrot set fractal:
#include <iostream>
#include <fstream> /* for ofstream */
#include <complex> /* for fractal arithmetic */

/**
 Affine function of two variables: evaluate(x,y) returns x*a + y*b + c.
 Used to map pixel coordinates onto the complex plane.
*/
class linear2d_function {
public:
	double a,b,c; /* coefficient of x, coefficient of y, constant offset */

	/* Replace all three coefficients. */
	void set(double a_,double b_,double c_) {
		a = a_;
		b = b_;
		c = c_;
	}

	linear2d_function(double a_,double b_,double c_) {
		set(a_,b_,c_);
	}

	double evaluate(double x,double y) const {
		const double from_x = x*a;
		const double from_y = y*b;
		return from_x + from_y + c;
	}
};

/* Render the Mandelbrot set one pixel at a time and write it to out.ppm
   as a binary (P6) PPM image.  Returns 0. */
int foo(void)
{
	// Figure out how big an image we should render:
	int wid=350, ht=256;

	// Create a PPM output image file header:
	std::ofstream out("out.ppm",std::ios_base::binary);
	out<<"P6\n"
	   <<wid<<" "<<ht<<"\n"
	   <<"255\n";

	// Set up coordinate system to render the Mandelbrot Set:
	double scale=3.0/wid;
	linear2d_function fx(scale,0.0,-scale*wid/2); // returns c given pixels
	linear2d_function fy(0.0,scale,-1.0);

	for (int y=0;y<ht;y++)
	for (int x=0;x<wid;x++) {
		/* Walk this Mandelbrot Set pixel */
		typedef std::complex<double> COMPLEX;
		COMPLEX c(fx.evaluate(x,y),fy.evaluate(x,y));
		COMPLEX z(0.0);
		int count;
		enum {max_count=100};
		for (count=0;count<max_count;count++) {
			z=z*z+c;
			if ((z.real()*z.real()+z.imag()*z.imag())>=4.0) break; /* |z|^2 >= 4: stop */
		}

		/* Figure out the output pixel color.
		   These channel values can be negative or larger than 255.  Casting such
		   a double straight to unsigned char is undefined behavior, so convert
		   to int first; int -> unsigned char is well defined and wraps modulo
		   256 (presumably the intended color banding).  The values stay small
		   because iteration stops right after |z|^2 reaches 4. */
		unsigned char r,g,b;
		r=static_cast<unsigned char>(static_cast<int>(z.real()*(256/2.0)));
		g=static_cast<unsigned char>(static_cast<int>(z.imag()*(256/2.0)));
		b=static_cast<unsigned char>(static_cast<int>((z.real()*z.real()+z.imag()*z.imag())*256));
		out<<r<<g<<b;
	}

	return 0;
}

(Try this in NetRun now!)

This takes 57.8ms to render a 350x256 PPM image.  The first thing to note is that this line is sequential:
		out<<r<<g<<b;
We can't write to the output file in parallel, so we'll have to collect the output values into an array of "pixel" objects first:
class pixel { /* onscreen RGB pixel */
public:
	unsigned char r,g,b;

	/* Leaves r,g,b uninitialized; arrays of pixels are filled in later. */
	pixel() {}

	/* Color this pixel from the final Mandelbrot iterate z. */
	pixel(const COMPLEX &z) {
		r=to_byte(z.real()*(256/2.0));
		g=to_byte(z.imag()*(256/2.0));
		b=to_byte((z.real()*z.real()+z.imag()*z.imag())*256);
	}

private:
	/* Wrap a channel value modulo 256.  Going through int keeps this well
	   defined: a direct double -> unsigned char cast is undefined behavior
	   when the value lies outside [0,255], which happens here (negative
	   real/imag parts, |z|^2*256 > 255). */
	static unsigned char to_byte(double value) {
		return static_cast<unsigned char>(static_cast<int>(value));
	}
};

int foo(void)
{
...
pixel *output=new pixel[wid*ht];

for (int y=0;y<ht;y++)
for (int x=0;x<wid;x++) {
...

output[x+y*wid]=pixel(z);
}
out.write((char *)&output[0],sizeof(pixel)*wid*ht);

return 0;
}

(Try this in NetRun now!)

Doing the output in one big block is nearly twice as fast--just 29.8ms.  Man!  Merely restructuring the code so it could run in parallel sped up this program substantially!  I think this is mostly because the inner loop now has no function calls, so the compiler doesn't have to save and restore registers across the I/O function call.

Mandelbrot SSE

Let's try to use SSE.  Here's the inner loop:
	COMPLEX c(fx.evaluate(x,y),fy.evaluate(x,y)); /* this pixel's c; fx, fy, x, y come from the enclosing code */
COMPLEX z(0.0);
enum {max_count=100};
/* count is declared by the surrounding code (not shown in this fragment). */
for (count=0;count<max_count;count++) {
z=z*z+c; /* complex multiply-add: the Mandelbrot iteration */
if ((z.real()*z.real()+z.imag()*z.imag())>=4.0) break; /* stop once |z|^2 >= 4 */
}
output[x+y*wid]=pixel(z); /* the pixel color depends only on the final z */
That "z*z" is a complex multiplication, so the line "z=z*z+c;" is really "z=COMPLEX(z.r*z.r-z.i*z.i+c.r, 2*z.r*z.i+c.i);".  Let's first rewrite this in terms of basic floating-point operations:
	COMPLEX c(fx.evaluate(x,y),fy.evaluate(x,y));
/* Same iteration as z=z*z+c, expanded into real/imaginary float arithmetic.
   Note that c is narrowed from double to float here. */
float cr=c.real(), ci=c.imag();
float zr=0.0, zi=0.0;
int count;
enum {max_count=100};
for (count=0;count<max_count;count++) {
float tzr=zr*zr-zi*zi+cr; /*subtle: don't overwrite zr yet!*/
float tzi=2*zr*zi+ci; /* uses the old zr, which is why tzr is a temporary */
zr=tzr; zi=tzi;
if ((zr*zr+zi*zi)>=4.0) break; /* stop once |z|^2 >= 4 */
}
output[x+y*wid]=pixel(COMPLEX(zr,zi)); /* widen back to COMPLEX for coloring */

(Try this in NetRun now!)

Surprisingly, again, this speeds up the program!  Now we're at 20ms.  We have two options for SSE:
  1. Put zr and zi into a single vector, and try to extract a little parallelism within the operations on it, like the duplicate multiplies.  Most of our work would be in shuffling values around.
  2. Put four different zr's into one vector, and four different zi's into another, then do the *exact* operations above for four separate pixels at once. 
Generally, option 1 is a little easier (for example, only one pixel at a time goes around the loop), but option 2 usually gives much higher performance.  To start with, we can just unroll the x loop 4 times, load everything into 4-vectors (named in ALL CAPS), and use the handy intrinsic _mm_movemask_ps, which packs the sign bits of the 4 floats into the low 4 bits of an int, to figure out when to stop the loop:
	float cr[4], ci[4];
	for (int i=0;i<4;i++) {
		cr[i]=fx.evaluate(x+i,y);
		ci[i]=fy.evaluate(x+i,y);
	}
	/* cr/ci are plain stack arrays, which C++ does not guarantee to be 16-byte
	   aligned, so use the unaligned load/store forms.  Constants are
	   broadcast directly instead of being loaded from arrays. */
	__m128 TWO=_mm_set1_ps(2.0f);
	__m128 FOUR=_mm_set1_ps(4.0f);
	__m128 CR=_mm_loadu_ps(cr), CI=_mm_loadu_ps(ci);
	__m128 ZR=_mm_setzero_ps(), ZI=_mm_setzero_ps();
	int count;
	enum {max_count=100};
	for (count=0;count<max_count;count++) {
		__m128 tZR=ZR*ZR-ZI*ZI+CR; /*subtle: don't overwrite ZR yet!*/
		__m128 tZI=TWO*ZR*ZI+CI;
		ZR=tZR; ZI=tZI;
		/* movemask packs the 4 sign bits: a lane's bit is set while |z|^2-4 < 0.
		   NOTE: every lane keeps iterating until *all* lanes have escaped, so
		   early-escaping pixels get extra iterations -- the wrong answer
		   discussed in the text. */
		if (_mm_movemask_ps(ZR*ZR+ZI*ZI-FOUR)==0) break;
	}
	_mm_storeu_ps(cr,ZR); _mm_storeu_ps(ci,ZI);
	for (int i=0;i<4;i++) {
		if (x+i<wid) /* wid need not be a multiple of 4: skip lanes past the row end */
			output[x+i+y*wid]=pixel(COMPLEX(cr[i],ci[i]));
	}

(Try this in NetRun now!)

This runs, and it's really super fast, at just 8.5 ms, but it gets the wrong answer.  The trouble is that all four floats are chained together in the loop until they're all done.  What we need is for only the pixels whose |z|^2 is still less than 4.0 to keep going around the loop, and the other pixels to stop changing.  We can use the standard (yet weird) SSE branch trick to fix this: compute the new values in all four lanes, build a comparison mask that is all-ones in each lane that should keep iterating, then use AND/ANDNOT/OR with that mask to keep the new value only in those lanes.  This is essentially a per-lane "if (this lane's |z|^2<FOUR) {ZR=tZR; ZI=tZI;}".
	float cr[4], ci[4];
	for (int i=0;i<4;i++) {
		cr[i]=fx.evaluate(x+i,y);
		ci[i]=fy.evaluate(x+i,y);
	}
	/* cr/ci are not guaranteed 16-byte aligned: use unaligned loads/stores. */
	__m128 TWO=_mm_set1_ps(2.0f);
	__m128 FOUR=_mm_set1_ps(4.0f);
	__m128 CR=_mm_loadu_ps(cr), CI=_mm_loadu_ps(ci);
	__m128 ZR=_mm_setzero_ps(), ZI=_mm_setzero_ps();
	int count;
	enum {max_count=100};
	for (count=0;count<max_count;count++) {
		/* A lane keeps iterating while its *current* |z|^2 < 4, the same test
		   the scalar loop uses.  Once a lane escapes it stops changing, so it
		   keeps the first escaped z -- the value the scalar loop breaks with. */
		__m128 LEN=ZR*ZR+ZI*ZI;
		__m128 MASK=_mm_cmplt_ps(LEN,FOUR); /* set if len<4.0 */
		if (_mm_movemask_ps(MASK)==0) break; /* everybody's >= 4 */
		__m128 tZR=ZR*ZR-ZI*ZI+CR; /*subtle: don't overwrite ZR yet!*/
		__m128 tZI=TWO*ZR*ZI+CI;
		/* Masked assignment: new value where MASK is set, old value elsewhere. */
		ZR=_mm_or_ps(_mm_and_ps(MASK,tZR),_mm_andnot_ps(MASK,ZR));
		ZI=_mm_or_ps(_mm_and_ps(MASK,tZI),_mm_andnot_ps(MASK,ZI));
	}
	_mm_storeu_ps(cr,ZR); _mm_storeu_ps(ci,ZI);
	for (int i=0;i<4;i++) {
		if (x+i<wid) /* wid need not be a multiple of 4: skip lanes past the row end */
			output[x+i+y*wid]=pixel(COMPLEX(cr[i],ci[i]));
	}

(Try this in NetRun now!)

This gets the right answer!  It is a tad slower, 10.9ms.  But getting the right answer is non-negotiable!

Making it Look Nice

The "floats" and "bools" classes we developed today dramatically simplify the code above, and also let us use AVX (8 floats at a time):
	for (int x=0;x<wid;x+=floats::n) {
		/* Walk floats::n adjacent Mandelbrot Set pixels at once */
		floats cr, ci; /* copy coordinates into floats */
		for (int i=0;i<floats::n;i++) {
			cr[i]=fx.evaluate(x+i,y);
			ci[i]=fy.evaluate(x+i,y);
		}
		floats zr=0.0, zi=0.0;
		int count;
		enum {max_count=100};
		for (count=0;count<max_count;count++) {
			/* Decide which lanes keep iterating.  This is the scalar loop's test
			   applied to the current z, so an escaped lane freezes at the first z
			   with |z|^2 >= 4 (the value the scalar loop breaks with). */
			bools keepgoing=(zr*zr+zi*zi<4.0f);
			if (!(keepgoing.any())) break; /* everybody's done */
			/* Compute next iteration values */
			floats nzr=zr*zr-zi*zi+cr;
			floats nzi=2.0f*zr*zi+ci;
			zr=keepgoing.if_then_else(nzr,zr); /* mixed branch */
			zi=keepgoing.if_then_else(nzi,zi);
		}
		/* Copy coordinates back out from floats.  wid need not be a multiple
		   of floats::n, so skip lanes past the end of the row (otherwise the
		   last row writes past the end of output). */
		for (int i=0;i<floats::n;i++) {
			if (x+i<wid)
				output[x+i+y*wid]=pixel(COMPLEX(zr[i],zi[i]));
		}
	}
(Try this in NetRun now!)
AVX plus the new fast Sandy Bridge processor takes us down to just 2.7ms!

Here are the finished "floats" and "bools" classes:
/**
AVX implementation of Dr. Lawlor's "floats" class.
Dr. Orion Lawlor, lawlor@alaska.edu, 2011-10-04 (public domain)
*/
#ifndef __OSL_FLOATS_H__
#define __OSL_FLOATS_H__

#ifndef __AVX__
#error "You need the AVX intrinsics for this floats.h to work."
#endif
#include <immintrin.h> /* AVX */
#include <ostream> /* std::ostream, for printing floats */


class floats; // forward declaration

/* A set of 8 boolean values, one per AVX lane. */
class bools {
	/* Each 32-bit lane holds either all-zero bits (false) or all-one bits (true). */
	__m256 v;

	/* One bit per lane: the high (sign) bit of each of the 8 lanes. */
	int sign_bits(void) const { return _mm256_movemask_ps(v); }
public:
	bools(__m256 val) : v(val) {}
	__m256 get(void) const { return v; }

	/* Lane-wise logical combinations. */
	bools operator&&(const bools &other) const { return bools(_mm256_and_ps(v,other.v)); }
	bools operator||(const bools &other) const { return bools(_mm256_or_ps(v,other.v)); }
	bools operator!=(const bools &other) const { return bools(_mm256_xor_ps(v,other.v)); }

	/* Per-lane select: true lanes take `then`, false lanes take `else_part`. */
	floats if_then_else(const floats &then,const floats &else_part) const;

	/* True when every lane equals allvalue (all 8 false, or all 8 true). */
	bool operator==(bool allvalue) const {
		const int bits = sign_bits();
		return allvalue ? (bits == 0xFF) : (bits == 0);
	}

	/* True when at least one lane is true. */
	bool any(void) const {
		return sign_bits() != 0;
	}
};


/**
Represents an entire set of float values: one AVX __m256 register holding
floats::n single-precision lanes.
*/
class floats {
	__m256 v; /* 8 floating point values */
public:
	enum {n=8}; /* number of floats inside us */

	/* Default constructor leaves the lanes uninitialized. */
	floats() {}
	floats(__m256 val) {v=val;}
	floats &operator=(__m256 val) {v=val; return *this;}
	__m256 get(void) const {return v;}

	/* Broadcast a single float into every lane. */
	floats(float x) {v=_mm256_broadcast_ss(&x);}
	floats &operator=(float x) {v=_mm256_broadcast_ss(&x); return *this;}

	/* Load floats::n consecutive floats from memory (no alignment required). */
	floats(const float *src) {v=_mm256_loadu_ps(src);}
	floats &operator=(const float *src) {v=_mm256_loadu_ps(src); return *this;}

	/* "Mask" load: if bool is true, load the float;
	   if bool is false, do not load the float and set it to zero.
	   The AVX masked load takes an integer-vector mask (only each lane's high
	   bit matters), so reinterpret the bools bits as __m256i. */
	void load_mask(const float *src,const bools &mask)
		{v=_mm256_maskload_ps(src,_mm256_castps_si256(mask.get()));}

	/** Basic arithmetic, returning floats */
	friend floats operator+(const floats &lhs,const floats &rhs)
		{return _mm256_add_ps(lhs.v,rhs.v);}
	friend floats operator-(const floats &lhs,const floats &rhs)
		{return _mm256_sub_ps(lhs.v,rhs.v);}
	friend floats operator*(const floats &lhs,const floats &rhs)
		{return _mm256_mul_ps(lhs.v,rhs.v);}
	friend floats operator/(const floats &lhs,const floats &rhs)
		{return _mm256_div_ps(lhs.v,rhs.v);}

	/* Compound assignment updates this set in place and returns it by reference. */
	floats &operator+=(const floats &rhs) {v=_mm256_add_ps(v,rhs.v); return *this;}
	floats &operator-=(const floats &rhs) {v=_mm256_sub_ps(v,rhs.v); return *this;}
	floats &operator*=(const floats &rhs) {v=_mm256_mul_ps(v,rhs.v); return *this;}
	floats &operator/=(const floats &rhs) {v=_mm256_div_ps(v,rhs.v); return *this;}

	/** Comparisons, returning "bools".
	    Ordered predicates give false for NaN lanes, like scalar C++ ==,<,<=,>,>=.
	    != uses the unordered predicate so NaN lanes compare not-equal, like
	    scalar C++ != (and so != is the exact negation of ==). */
	bools operator==(const floats &rhs) const {return _mm256_cmp_ps(v,rhs.v,_CMP_EQ_OQ);}
	bools operator!=(const floats &rhs) const {return _mm256_cmp_ps(v,rhs.v,_CMP_NEQ_UQ);}
	bools operator<(const floats &rhs) const {return _mm256_cmp_ps(v,rhs.v,_CMP_LT_OQ);}
	bools operator<=(const floats &rhs) const {return _mm256_cmp_ps(v,rhs.v,_CMP_LE_OQ);}
	bools operator>(const floats &rhs) const {return _mm256_cmp_ps(v,rhs.v,_CMP_GT_OQ);}
	bools operator>=(const floats &rhs) const {return _mm256_cmp_ps(v,rhs.v,_CMP_GE_OQ);}

	/* Store to unaligned memory */
	void store(float *ptr) const { _mm256_storeu_ps(ptr,v); }
	/* Store with mask: only lanes whose bool is true are written; the other
	   destination floats are left untouched.  The intrinsic's argument order
	   is (address, mask, values). */
	void store_mask(float *ptr,const bools &mask) const
		{ _mm256_maskstore_ps(ptr,_mm256_castps_si256(mask.get()),v); }
	/* Store to 256-bit (32-byte) aligned memory (if not aligned, will segfault!) */
	void store_aligned(float *ptr) const { _mm256_store_ps(ptr,v); }

	/* Extract one float from our set. index must be between 0 and n-1.
	   NOTE(review): reads the __m256 through a float pointer; this relies on
	   the compiler permitting that aliasing (GCC/Clang vector types do). */
	float &operator[](int index) { return reinterpret_cast<float *>(&v)[index]; }
	float operator[](int index) const { return reinterpret_cast<const float *>(&v)[index]; }

	/* Print every lane, each followed by a space. */
	friend std::ostream &operator<<(std::ostream &o,const floats &y) {
		for (int i=0;i<floats::n;i++) o<<y[i]<<" ";
		return o;
	}
};

/* Per-lane select: computes (mask & then) | (~mask & else_part), so lanes whose
   bool is all-ones take `then` and all-zero lanes take `else_part`. */
inline floats bools::if_then_else(const floats &then,const floats &else_part) const {
	const __m256 from_then=_mm256_and_ps(v,then.get());
	const __m256 from_else=_mm256_andnot_ps(v,else_part.get());
	return floats(_mm256_or_ps(from_then,from_else));
}
/* Lane-wise maximum of two sets of floats.
   NaN handling follows _mm256_max_ps, not std::max. */
inline floats max(const floats &lhs,const floats &rhs) {
	const __m256 larger=_mm256_max_ps(lhs.get(),rhs.get());
	return floats(larger);
}
/* Lane-wise minimum; NaN handling follows _mm256_min_ps. */
inline floats min(const floats &lhs,const floats &rhs) {
	const __m256 smaller=_mm256_min_ps(lhs.get(),rhs.get());
	return floats(smaller);
}
/* Lane-wise rounding and root helpers: each applies one AVX operation to every lane. */
inline floats ceil(const floats &f) {
	return floats(_mm256_ceil_ps(f.get()));
}
inline floats floor(const floats &f) {
	return floats(_mm256_floor_ps(f.get()));
}
inline floats sqrt(const floats &f) {
	return floats(_mm256_sqrt_ps(f.get()));
}
/* Approximate 1/sqrt(x): fast hardware estimate with reduced precision. */
inline floats rsqrt(const floats &f) {
	return floats(_mm256_rsqrt_ps(f.get()));
}
#endif
Feel free to use these for anything you like!