Java Floating Point: strictfp, StrictMath vs. LingPipe’s Forward-Backward Algorithm


(Note: there are two questions I have, highlighted in bold.)

I’ve been wrestling with floating-point behavior that varies across platforms, a bug (er, feature) of modern programming languages running on multiple platforms and compilers. Here’s a nice intro from Sun that pre-dates Java, but discusses the IEEE floating point standard, the main platform-based variations, and their effects on the C language:

David Goldberg. 1991. What Every Computer Scientist Should Know About Floating-Point Arithmetic. ACM Computing Surveys. Here’s the whole book as a PDF.

Goldberg’s paper is just an appendix; the rest of the document goes into gory detail about floating point, much more so than my numerical optimization (Nocedal and Wright) or matrix computation (Golub and van Loan) books do.

One of the features of LingPipe’s implementation of hidden Markov models (HMMs) is a full implementation of the forward-backward algorithm (also see Bishop’s pattern recognition textbook). Forward-backward is just a special case of the general sum-product algorithm for graphical models. It’s also used in conditional random fields (CRFs). Forward-backward is what allows us to calculate the probability of a tag for a token given the whole sequence of input tokens. It’s also what’s behind the same calculation for named entities in our HMM named-entity extractor.

As I mentioned previously in a post on POS tagger eval, Matthew Wilkins has a corpus of 1.6M tokens of literary text with POS tags. During the evaluation, our HMM decoder threw an exception in fold 347 of a 500-fold cross-validation run (I used so many folds to avoid blowing out memory storing cases; I’ll make storage optional in the next release).

I fixed the problem by adding a strictfp declaration to every class and replacing all calls to java.lang.Math with the corresponding calls to java.lang.StrictMath. Decoding wasn’t noticeably slower on my desktop machine (dual quad-core Xeon, Vista64) under either Java 1.5 or 1.6. Is this what others have found with strict math? What I read on the web led me to believe it’d be much slower.
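For anyone who hasn’t used them, here is a minimal sketch of what the two changes look like in a class. The strictfp modifier makes every floating-point expression in the class evaluate with strict IEEE 754 semantics, and java.lang.StrictMath’s functions are specified to return the same results as the fdlibm algorithms on every platform, whereas java.lang.Math only promises results within a stated error bound. The class and method names (StrictLogProb, toLog, logProduct) are hypothetical illustrations, not LingPipe code.

```java
// Hypothetical helper illustrating the two changes:
// (1) the strictfp modifier on the class, (2) StrictMath in place of Math.
final strictfp class StrictLogProb {

    // Converts linear-scale probabilities to natural-log scale.
    // StrictMath.log is specified to match fdlibm bit for bit on every platform;
    // Math.log may use faster platform-specific implementations.
    static double[] toLog(double[] probs) {
        double[] logs = new double[probs.length];
        for (int i = 0; i < probs.length; ++i) {
            logs[i] = StrictMath.log(probs[i]);
        }
        return logs;
    }

    // Log of a product of probabilities, computed as a sum of logs.
    static double logProduct(double[] probs) {
        double sum = 0.0;
        for (double lp : toLog(probs)) {
            sum += lp;
        }
        return sum;
    }

    public static void main(String[] args) {
        double lp = logProduct(new double[] {0.5, 0.25});
        // lp should be log(0.125), up to rounding.
        System.out.println(lp + " " + StrictMath.exp(lp));
    }
}
```

On Java 17 and later, all floating-point arithmetic is strict (JEP 306), so the strictfp modifier is redundant there and compilers may warn about it; the StrictMath calls still matter for reproducible library-function results.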

The bug was in the middle of the forward portion of the forward-backward algorithm. Forward-backward is a nasty algorithm numerically because it involves a long chain of multiply-adds (basically vector dot products, where the output of one set of computations makes up the input to the next). Here’s how the forward values (alphas) are defined for a given token position (n) and category/tag (t):

\alpha_{t,0} = p(t \mid \mathrm{start}) \, p(w_0 \mid t)

\alpha_{t,n} = \sum_{t'} \alpha_{t',n-1} \, p(t \mid t') \, p(w_n \mid t)

For each input token n and each tag t, the recursion accumulates a multiply-add for every previous tag t', each involving the output of the previous (n-1) calculation! The algorithm’s not as bad as it looks because we store the previous alphas and use dynamic programming, so with T tags and N tokens the whole forward pass takes O(N T^2) time.
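Here is a minimal sketch of that recursion on a linear scale, with no rescaling. The array layout is an assumption made for the example, not LingPipe’s API: start[t] holds p(t|start), trans[tPrev][t] holds p(t|tPrev), and emit[n][t] holds p(w_n|t) precomputed for the input tokens. Since p(w_n|t) doesn’t depend on t', it is pulled out of the inner sum.

```java
// Hypothetical unscaled forward pass for an HMM (illustration only).
final class Forward {

    // Returns alpha[n][t] = p(w_0..w_n, tag_n = t).
    static double[][] alphas(double[] start, double[][] trans, double[][] emit) {
        int numTokens = emit.length;
        int numTags = start.length;
        double[][] alpha = new double[numTokens][numTags];
        // Base case: alpha[0][t] = p(t|start) * p(w_0|t)
        for (int t = 0; t < numTags; ++t) {
            alpha[0][t] = start[t] * emit[0][t];
        }
        // Recursion: alpha[n][t] = p(w_n|t) * sum_{t'} alpha[n-1][t'] * p(t|t')
        for (int n = 1; n < numTokens; ++n) {
            for (int t = 0; t < numTags; ++t) {
                double sum = 0.0;
                for (int tPrev = 0; tPrev < numTags; ++tPrev) {
                    sum += alpha[n - 1][tPrev] * trans[tPrev][t];
                }
                alpha[n][t] = sum * emit[n][t];
            }
        }
        return alpha;
    }

    public static void main(String[] args) {
        double[] start = {0.6, 0.4};
        double[][] trans = {{0.7, 0.3}, {0.4, 0.6}};
        double[][] emit = {{0.5, 0.1}, {0.4, 0.3}};
        double[][] alpha = alphas(start, trans, emit);
        // The final column sums to p(w_0, w_1), the probability of the whole input.
        System.out.println(alpha[1][0] + alpha[1][1]);
    }
}
```

The three nested loops are the O(N T^2) cost; each column depends only on the one before it, which is what makes the dynamic programming work.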

Because the forward values (alphas) are joint probabilities of the input tokens up through position n and the tag assigned at that position, they can easily underflow in practice on a linear scale. That’s why Viterbi (the first-best algorithm) uses log scales (that, and logs convert multiplication to addition).
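To make the underflow point concrete, here is a small sketch that multiplies a made-up per-token probability of 0.01 across 200 tokens. On a linear scale the true value, 1e-400, is below the smallest positive double (about 4.9e-324), so the product collapses to zero; on a log scale the same quantity is about -921 and perfectly representable. The class and method names are hypothetical.

```java
// Illustration only: the per-token probability and sequence length are made up.
final class UnderflowDemo {

    // Multiplies p by itself numTokens times on a linear scale.
    static double linearProduct(double p, int numTokens) {
        double product = 1.0;
        for (int n = 0; n < numTokens; ++n) {
            product *= p;
        }
        return product;
    }

    // The same quantity on a natural-log scale, where products become sums.
    static double logProduct(double p, int numTokens) {
        double logP = Math.log(p);
        double sum = 0.0;
        for (int n = 0; n < numTokens; ++n) {
            sum += logP;
        }
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(linearProduct(0.01, 200));  // prints 0.0 (underflow)
        System.out.println(logProduct(0.01, 200));     // roughly -921.03
    }
}
```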

In practice, we really only care about the conditional probabilities. So we can rescale:

\alpha'_{t,n} = \frac{\alpha_{t,n}}{\sum_{t'} \alpha_{t',n}}
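A minimal sketch of the rescaled forward pass follows, assuming a hypothetical array layout (start[t] = p(t|start), trans[tPrev][t] = p(t|tPrev), emit[n][t] = p(w_n|t)); it is not LingPipe’s implementation. It normalizes each column as soon as it is computed. Because the recursion is linear in the previous column, scaling that column by a constant scales every entry of the next column by the same constant, which the next normalization cancels, so each column ends up equal to the rescaled alpha' above while staying near 1 instead of shrinking toward underflow.

```java
// Hypothetical rescaled forward pass (illustration only).
final class ScaledForward {

    // Returns alphaPrime[n][t] = alpha[n][t] / sum_{t'} alpha[n][t'].
    static double[][] scaledAlphas(double[] start, double[][] trans, double[][] emit) {
        int numTokens = emit.length;
        int numTags = start.length;
        double[][] alpha = new double[numTokens][numTags];
        for (int t = 0; t < numTags; ++t) {
            alpha[0][t] = start[t] * emit[0][t];
        }
        normalize(alpha[0]);
        for (int n = 1; n < numTokens; ++n) {
            for (int t = 0; t < numTags; ++t) {
                double sum = 0.0;
                for (int tPrev = 0; tPrev < numTags; ++tPrev) {
                    sum += alpha[n - 1][tPrev] * trans[tPrev][t];
                }
                alpha[n][t] = sum * emit[n][t];
            }
            normalize(alpha[n]);  // rescale before the next step can underflow
        }
        return alpha;
    }

    // Divides each entry by the column total, in place.
    static void normalize(double[] column) {
        double total = 0.0;
        for (double a : column) {
            total += a;
        }
        if (total == 0.0) {
            throw new IllegalStateException("column sums to zero; cannot rescale");
        }
        for (int t = 0; t < column.length; ++t) {
            column[t] /= total;
        }
    }

    public static void main(String[] args) {
        double[] start = {0.6, 0.4};
        double[][] trans = {{0.7, 0.3}, {0.4, 0.6}};
        double[][] emit = {{0.5, 0.1}, {0.4, 0.3}};
        double[][] a = scaledAlphas(start, trans, emit);
        System.out.println(a[1][0] + " " + a[1][1]);
    }
}
```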

As those of you who follow floating point know, this is all risky business. What happened in practice was that the rescaled values summed to 1.09 at one point (I only rescaled up, so this is a problem in the forward calculations at some point).
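One way to catch this kind of drift is to check each rescaled column as it is produced. A hypothetical check along these lines (not LingPipe’s actual code) has to allow some tolerance, because even a correct sum rarely lands exactly on 1.0: adding 0.1 ten times in double precision gives 0.9999999999999999. The tolerance of 1e-9 below is an arbitrary choice for illustration; a sum of 1.09 is far outside it.

```java
// Hypothetical sanity check on a rescaled column of alphas (illustration only).
final class NormalizationCheck {

    // Throws if the entries do not sum to 1 within the given tolerance.
    static void checkSumsToOne(double[] column, double tolerance) {
        double total = 0.0;
        for (double a : column) {
            total += a;
        }
        if (Math.abs(total - 1.0) > tolerance) {
            throw new IllegalStateException("rescaled alphas sum to " + total);
        }
    }

    public static void main(String[] args) {
        double[] tenths = new double[10];
        java.util.Arrays.fill(tenths, 0.1);
        double total = 0.0;
        for (double x : tenths) {
            total += x;
        }
        System.out.println(total);        // prints 0.9999999999999999, not 1.0
        checkSumsToOne(tenths, 1e-9);     // passes: off only by rounding error
        try {
            checkSumsToOne(new double[] {0.6, 0.49}, 1e-9);  // sums to about 1.09
        } catch (IllegalStateException e) {
            System.out.println("caught: " + e.getMessage());
        }
    }
}
```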

Finally: Does anyone know any references about calculating forward-backward in a way that’s sensitive to how floating point arithmetic actually works?
