
(An alternative title for this post is “What is the type of differentiation?” Hint: it’s not quite (ℝ→ℝ)→(ℝ→ℝ), because how would you make sure the input function is differentiable?)

You can download this post as a literate Haskell program.

```haskell
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE OverlappingInstances #-}
module Differentiation where
```

## Automatic differentiation

*The overloading approach to automatic differentiation* has become more popular recently among typed functional programmers. The basic idea is to overload arithmetic operators such as +, ×, and sin so that they work on not just floating-point numbers but *pairs* (or more generally, *sequences*) of them, which track quantities along with their rates of change. The overloaded operators are easy to define because, unlike with integration, the rules of differentiation are compositional: you know, d(x + y) = dx + dy, d(x × y) = dx × y + x × dy, d(sin x) = cos x × dx, and so on. Here they are in Haskell.

```haskell
-- A quantity paired with its derivative (rate of change).
data D a = D a a deriving Show

-- A constant quantity: derivative 0.
lift :: Num a => a -> D a
lift x = D x 0

-- The infinitesimal itself: value 0, derivative 1.
infinitesimal :: Num a => D a
infinitesimal = D 0 1

-- Comparisons look only at the value, not at the derivative.
instance Eq a => Eq (D a) where
  D x _ == D y _ = x == y

instance Ord a => Ord (D a) where
  compare (D x _) (D y _) = compare x y

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')  -- product rule
  negate (D x x') = D (negate x) (negate x')
  abs (D x x') = D (abs x) (signum x * x')
  signum (D x _) = lift (signum x)
  fromInteger x = lift (fromInteger x)

instance Fractional a => Fractional (D a) where
  recip (D x x') = D (recip x) (-x'/x/x)          -- d(1/x) = -dx/x^2
  fromRational x = lift (fromRational x)
```
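The prose lists d(sin x) = cos x × dx, but the code above stops at `Fractional`, so `sin` cannot yet be applied to a `D` value. A `Floating` instance in the same style might look like the sketch below; it is an addition, not part of the original program, and each method simply applies the chain rule d(f u) = f′(u) × du. The block repeats the definitions above (minus `Eq`, `Ord`, and `Show`) so that it compiles on its own.

```haskell
data D a = D a a

lift :: Num a => a -> D a
lift x = D x 0

infinitesimal :: Num a => D a
infinitesimal = D 0 1

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs (D x x') = D (abs x) (signum x * x')
  signum (D x _) = lift (signum x)
  fromInteger n = lift (fromInteger n)

instance Fractional a => Fractional (D a) where
  recip (D x x') = D (recip x) (-x'/x/x)
  fromRational r = lift (fromRational r)

-- Sketch: every method is f applied to the value, and f' times the
-- incoming derivative (chain rule).
instance Floating a => Floating (D a) where
  pi = lift pi
  exp (D x x') = D (exp x) (exp x * x')
  log (D x x') = D (log x) (x' / x)
  sqrt (D x x') = D (sqrt x) (x' / (2 * sqrt x))
  sin (D x x') = D (sin x) (cos x * x')
  cos (D x x') = D (cos x) (negate (sin x) * x')
  asin (D x x') = D (asin x) (x' / sqrt (1 - x * x))
  acos (D x x') = D (acos x) (negate x' / sqrt (1 - x * x))
  atan (D x x') = D (atan x) (x' / (1 + x * x))
  sinh (D x x') = D (sinh x) (cosh x * x')
  cosh (D x x') = D (cosh x) (sinh x * x')
  asinh (D x x') = D (asinh x) (x' / sqrt (x * x + 1))
  acosh (D x x') = D (acosh x) (x' / sqrt (x * x - 1))
  atanh (D x x') = D (atanh x) (x' / (1 - x * x))
```

For example, `sin (lift 2 * infinitesimal)` evaluates to `D 0.0 2.0`: the value sin 0 and the derivative of sin 2x at 0.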

The two components of a `D` value are a quantity and its derivative, so the `lift` function ‘lifts’ a number to a constant quantity, and `infinitesimal` is a quantity with value 0 and derivative 1. Were `abs` defined by `abs x = signum x * x` by default, we wouldn’t have to define `abs` for `D` above.

Let us use these definitions to model how a parabola reflects light.

Suppose that a light ray (red above) enters a parabolic mirror *y* = *x*²/4 from above. Where does the reflected ray cross the *y* axis? In the diagram above, if the point where the ray hits the parabola is (*x*, *y*), then the derivative of *y* with respect to *x* is tan *θ*, and the *y* coordinate where the reflected ray crosses the *y* axis is *y* + *x*/tan 2*θ*, which by the double-angle formula for the tangent equals *y* + *x* (1/tan *θ* − tan *θ*)/2. We can compute this coordinate by automatically differentiating *y* with respect to *x*:

```haskell
curve x = x^2/4

reflect x = let D y y' = curve (lift x + infinitesimal)
            in y + x * (recip y' - y') / 2
```

As expected, the parabola reflects all incoming rays from above to the focal point (0,1).

```
*Differentiation> map reflect [1..5]
[1.0,1.0,1.0,1.0,1.0]
```
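For the parabola, the result can also be checked by hand. The double-angle identity for the tangent gives the second form of the crossing height, and substituting the slope tan *θ* = *x*/2 of *y* = *x*²/4 (for *x* ≠ 0) makes every term in *x* cancel:

$$
\frac{1}{\tan 2\theta} = \frac{1 - \tan^2\theta}{2\tan\theta} = \frac{1}{2}\left(\frac{1}{\tan\theta} - \tan\theta\right)
$$

$$
y + \frac{x}{2}\left(\frac{1}{\tan\theta} - \tan\theta\right) = \frac{x^2}{4} + \frac{x}{2}\left(\frac{2}{x} - \frac{x}{2}\right) = \frac{x^2}{4} + 1 - \frac{x^2}{4} = 1
$$

So the crossing point is (0, 1) for every *x*, matching the output above.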

To be sure, this code does not work by symbolically differentiating *y* = *x*²/4 to yield *y*′ = *x*/2. Rather, it computes *y* alongside *y*′ for one particular *x* at a time, so the curve could just as well be defined by a more complex program that uses `if` and recursion. For example, a ray tracer usually deals with scenes much more complex than a single parabola. This code also does not work by numerically comparing the values of *y* at nearby values of *x*. Instead of approximating dy/dx by Δy/Δx where Δx is a very small real number, we compute with an actually infinitesimal dx.
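To make the point about `if` and recursion concrete, here is a sketch that is not from the original post: a hypothetical curve `bumpy`, defined piecewise with `if` and a recursive `power`, differentiated exactly by `derivAD` and approximately by a forward difference `derivFD` (both helper names are also hypothetical). The `D` definitions are repeated so the block compiles on its own.

```haskell
data D a = D a a

lift :: Num a => a -> D a
lift x = D x 0

infinitesimal :: Num a => D a
infinitesimal = D 0 1

instance Eq a => Eq (D a) where
  D x _ == D y _ = x == y

instance Ord a => Ord (D a) where
  compare (D x _) (D y _) = compare x y

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs (D x x') = D (abs x) (signum x * x')
  signum (D x _) = lift (signum x)
  fromInteger n = lift (fromInteger n)

instance Fractional a => Fractional (D a) where
  recip (D x x') = D (recip x) (-x'/x/x)
  fromRational r = lift (fromRational r)

-- Hypothetical: x^n by explicit recursion on n.
power :: Num a => a -> Int -> a
power _ 0 = 1
power x n = x * power x (n - 1)

-- Hypothetical piecewise curve: the parabola x^2/4 for x >= 0 and a
-- cubic x^3/8 for x < 0. The `if` compares only the value component.
bumpy :: (Ord a, Fractional a) => a -> a
bumpy x = if x >= 0 then power x 2 / 4 else power x 3 / 8

-- Exact derivative at x, by automatic differentiation.
derivAD :: Num a => (D a -> D a) -> a -> a
derivAD f x = let D _ y' = f (lift x + infinitesimal) in y'

-- Approximate derivative at x, by a forward difference with step dx.
derivFD :: Fractional a => (a -> a) -> a -> a -> a
derivFD f dx x = (f (x + dx) - f x) / dx
```

Here `derivAD bumpy 3` and `derivAD bumpy (-2)` both give exactly 1.5, each the derivative of the branch that the `if` selects at that point, while `derivFD bumpy 1e-6 3` is only close to 1.5 because its step is small but not infinitesimal.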

## Differentiation is a higher-order function

Before we continue, let us abstract the pattern for differentiation in `reflect` above into a new higher-order function `d`, which differentiates any given function at the input 0. For convenience, `d` also returns the value of the given function at 0.

```haskell
d :: Num a => (D a -> D b) -> (b, b)
d f = let D y y' = f infinitesimal in (y,y')

reflect :: Fractional a => a -> a
reflect x = let (y,y') = d (\h -> curve (lift x + h))
            in y + x * (recip y' - y') / 2
```
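A small usage sketch of `d` (the example functions `cubeAt2` and `recipAt4` are illustrative, not from the post): shifting the argument by `h`, as `reflect` does with `lift x + h`, yields the value and the slope at a chosen point.

```haskell
data D a = D a a

lift :: Num a => a -> D a
lift x = D x 0

infinitesimal :: Num a => D a
infinitesimal = D 0 1

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs (D x x') = D (abs x) (signum x * x')
  signum (D x _) = lift (signum x)
  fromInteger n = lift (fromInteger n)

instance Fractional a => Fractional (D a) where
  recip (D x x') = D (recip x) (-x'/x/x)
  fromRational r = lift (fromRational r)

d :: Num a => (D a -> D b) -> (b, b)
d f = let D y y' = f infinitesimal in (y, y')

-- x^3 at x = 2: value 8, derivative 3 * 2^2 = 12.
cubeAt2 :: (Double, Double)
cubeAt2 = d (\h -> (lift 2 + h) ^ 3)

-- 1/x at x = 4: value 0.25, derivative -1/16.
recipAt4 :: (Double, Double)
recipAt4 = d (\h -> recip (lift 4 + h))
```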

Even though the `reflect` function uses differentiation internally, we can still differentiate it. Such differentiation is said to be *nested*. For the parabola, we can confirm that the reflected ray hits the focal point not just at *x* = 3 but also around it, because the derivative computed below is zero.

```
*Differentiation> d (\k -> reflect (3 + k))
(1.0,0.0)
```

For other curves and surfaces, the derivative is typically not
zero and tells us the density of light energy that falls around
each point on the *y* axis.
Hence, as Dan Piponi noted, this kind of calculation is performed
by ray tracers and other programs that sample from probability
distributions.

Another application of automatic differentiation is to find roots of a function using Newton’s method. Because local extrema and saddle points of a function are roots of its derivative, nested automatic differentiation lets us find them using Newton’s method as well.
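As a sketch of both uses (all names here are hypothetical, and the rank-2 argument type is an assumption that lets one function be evaluated at several levels of `D`), `newton` iterates Newton’s method for a root, and `critical` runs the same iteration on the derivative, which requires differentiating a function that itself differentiates:

```haskell
{-# LANGUAGE RankNTypes #-}

data D a = D a a

lift :: Num a => a -> D a
lift x = D x 0

infinitesimal :: Num a => D a
infinitesimal = D 0 1

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs (D x x') = D (abs x) (signum x * x')
  signum (D x _) = lift (signum x)
  fromInteger n = lift (fromInteger n)

instance Fractional a => Fractional (D a) where
  recip (D x x') = D (recip x) (-x'/x/x)
  fromRational r = lift (fromRational r)

d :: Num a => (D a -> D b) -> (b, b)
d f = let D y y' = f infinitesimal in (y, y')

-- Hypothetical helper: the derivative of f at x. The rank-2 type lets f be
-- used at D a here, and at D (D a) when deriv itself is differentiated.
deriv :: Fractional a => (forall t. Fractional t => t -> t) -> a -> a
deriv f x = snd (d (\h -> f (lift x + h)))

-- Newton iterates for a root of f: x - f x / f' x.
newton :: Fractional a => (forall t. Fractional t => t -> t) -> a -> [a]
newton f = iterate step
  where
    step x = let (y, y') = d (\h -> f (lift x + h)) in x - y / y'

-- Critical points are roots of f', so Newton on f' also needs f'':
-- nested differentiation, one infinitesimal per level.
critical :: Fractional a => (forall t. Fractional t => t -> t) -> a -> [a]
critical f = newton (\t -> deriv f t)
```

For example, `newton (\x -> x * x - 2) 1 !! 10` approximates √2, and `critical (\x -> x * x * x - 3 * x) 2 !! 10` settles on the local minimum of *x*³ − 3*x* at *x* = 1; which critical point the iteration reaches depends on the starting guess.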

## The danger of confusing infinitesimals

Jeff Siskind and Barak Pearlmutter pointed out a kind of programmer mistake that makes nested differentiation give the wrong result. As they show, this kind of mistake is easy to make in the framework defined above, because all it takes is putting a call to the `lift` function in the wrong place. The definition of `reflectBug` below is only slightly different from `reflect`, but the result is very different and very wrong.

```haskell
reflectBug x = let (y,y') = d (\h -> lift (curve (x + h)))
               in y + x * (recip y' - y') / 2
```

```
*Differentiation> d (\k -> reflectBug (3 + k))
(Infinity,NaN)
```

They demonstrate this problem using running code in Haskell and Scheme that computes

$$
\left.\frac{d}{dx}\left(x \cdot \left.\frac{d}{dy}(x + y)\right|_{y=1}\right)\right|_{x=1}
$$

to be 2 rather than the correct answer 1.
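The same confusion can be reproduced with the unbranded `D` and `d` defined earlier. The sketch below (the names `confused` and `careful` are hypothetical) encodes the example above both ways: once with the inner infinitesimal at the same type as the outer one, where a misplaced `lift` is what lets the code type-check, and once one level up, where `x` is lifted instead so that the two perturbations stay apart.

```haskell
data D a = D a a

lift :: Num a => a -> D a
lift x = D x 0

infinitesimal :: Num a => D a
infinitesimal = D 0 1

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs (D x x') = D (abs x) (signum x * x')
  signum (D x _) = lift (signum x)
  fromInteger n = lift (fromInteger n)

d :: Num a => (D a -> D b) -> (b, b)
d f = let D y y' = f infinitesimal in (y, y')

-- Buggy: k has the same type as x, so the inner d reuses the outer
-- infinitesimal and picks up x's perturbation too.
confused :: Double
confused = snd (d (\h -> let x = lift 1 + h
                         in x * lift (snd (d (\k -> x + (lift 1 + k))))))

-- Correct: the inner function works on D (D Double), and x is lifted
-- to meet k there.
careful :: Double
careful = snd (d (\h -> let x = lift 1 + h
                        in x * snd (d (\k -> lift x + (lift 1 + k)))))
```

Here `confused` evaluates to 2.0 and `careful` to 1.0.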

The essence of this problem is that the two nested invocations
of differentiation use two different infinitesimals, which a
mathematician would denote by *dh* and *dk*.
These two infinitesimals should not be confused, just as years and
feet and persons should not be confused.

## Using types to check units

Björn Buckwalter showed that we can use a generic type system to prevent such confusion statically, just as Haskell uses *state threads* to distinguish pointers into different memory regions. Recall that the type `ST s a` in Haskell represents a monadic computation that yields a result of type `a` using mutable cells in the state thread represented by the phantom type `s`. To construct such a computation, we can use primitive operations such as

`newSTRef :: a -> ST s (STRef s a)`

and the fact that `ST s` is a monad. To run such a computation, we must use the primitive function

`runST :: (forall s. ST s a) -> a`

in which `forall s` forces different state threads, created by different calls to `runST`, to be represented by different phantom types `s`. One way to understand this rank-2 type is that it makes the type checker generate a new phantom type `s` for each argument to `runST`. Analogously, Buckwalter redefines the type constructor `D` to take a phantom-type argument `s`, which represents an infinitesimal unit.

```haskell
data D s a = D a a deriving Show
```

Accordingly, the other definitions above change in their types, but not in their terms or behavior. The most important change is the new rank-2 type of `d`, which forces different infinitesimals, created by different invocations of differentiation, to be represented by different phantom types `s`.

```haskell
d :: Num a => (forall s. D s a -> D s a) -> (a, a)
d f = let D y y' = f infinitesimal in (y,y')

lift :: Num a => a -> D s a
lift x = D x 0

infinitesimal :: Num a => D s a
infinitesimal = D 0 1

instance Eq a => Eq (D s a) where
  D x _ == D y _ = x == y

instance Ord a => Ord (D s a) where
  compare (D x _) (D y _) = compare x y

instance Num a => Num (D s a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs (D x x') = D (abs x) (signum x * x')
  signum (D x _) = lift (signum x)
  fromInteger x = lift (fromInteger x)

instance Fractional a => Fractional (D s a) where
  recip (D x x') = D (recip x) (-x'/x/x)
  fromRational x = lift (fromRational x)
```

Because the phantom type `s` is part of the type of a number, and because arithmetic operations such as `+` require the arguments and the return value to have the same type, it is a type error to add numbers denominated in different infinitesimals. In particular, the erroneous definition `reflectBug` above is now a type error, as desired.

```
Occurs check: cannot construct the infinite type: t = D s t
  Expected type: t
  Inferred type: D s t
In the first argument of `lift', namely `(curve (x + h))'
In the expression: lift (curve (x + h))
```
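Correctly nested code, by contrast, still type-checks under the branded definitions and computes the same answers as before. The sketch below (the names `careful` and `secondAt2` are hypothetical) encodes the correct version of the Siskind–Pearlmutter example and a second derivative of *x*³ at 2. Writing `x + (lift 1 + k)` in the inner function instead of `lift x + (lift 1 + k)` would now be rejected, because `k` carries the inner phantom type and `x` the outer one.

```haskell
{-# LANGUAGE RankNTypes #-}

data D s a = D a a

lift :: Num a => a -> D s a
lift x = D x 0

infinitesimal :: Num a => D s a
infinitesimal = D 0 1

instance Num a => Num (D s a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs (D x x') = D (abs x) (signum x * x')
  signum (D x _) = lift (signum x)
  fromInteger n = lift (fromInteger n)

d :: Num a => (forall s. D s a -> D s a) -> (a, a)
d f = let D y y' = f infinitesimal in (y, y')

-- d/dx (x * d/dy (x + y)) at x = y = 1: x is lifted into the inner
-- unit before it meets the inner infinitesimal k.
careful :: Double
careful = snd (d (\h -> let x = lift 1 + h
                        in x * snd (d (\k -> lift x + (lift 1 + k)))))

-- (f'(2), f''(2)) for f x = x^3, by nesting d.
secondAt2 :: (Double, Double)
secondAt2 = d (\h -> snd (d (\k -> (lift (lift 2 + h) + k) ^ 3)))
```

Here `careful` is 1.0 and `secondAt2` is `(12.0,12.0)`.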

For this checking of infinitesimal units to be sound, this library for automatic differentiation should not export the data constructor `D` or the value `infinitesimal` to its users, though of course the type constructor `D` and its type-class instances need to be exported, along with the functions `d` and `lift`.

(Alternatively, the library could export, for differentiation, only a function of the type

```haskell
d :: Num a => (forall b. Num b => (a -> b) -> b -> b) -> (a, a)
```

The type variable `b` makes it unnecessary and useless to export the type constructor `D`, even though `d` is still implemented using `D`. Also, the new argument of type `a -> b` makes it unnecessary and useless to export the `lift` function. Finally, the type-class context `Num b` makes it unnecessary and useless to export the `Eq`, `Show`, and `Num` instances for `D` (in Haskell 98, `Eq` and `Show` are superclasses of `Num`). However, as Chris Smith lamented, we need additional differentiation functions of the types

```haskell
Fractional a => (forall b. Fractional b => (a -> b) -> b -> b) -> (a, a)
Floating a => (forall b. Floating b => (a -> b) -> b -> b) -> (a, a)
```

in order to differentiate functions that use `Fractional` or `Floating` operations. Oleg Kiselyov used similar types to express *symbolic* differentiation.)
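The alternative interface in the parenthetical can be implemented on top of the unbranded `D` from the first section. This is a sketch under that assumption, not code from the post: `d` hands the caller a lifting function and the infinitesimal at an abstract type `b`, so the caller can neither inspect `D` nor mix levels. In the nested example, `up' x` plays the role that `lift x` played before, and adding `x` to `k` directly would be a type error because they have the distinct abstract types `b` and `c`.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Dual numbers; a library would export neither D nor its constructor.
data D a = D a a

instance Num a => Num (D a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs (D x x') = D (abs x) (signum x * x')
  signum (D x _) = D (signum x) 0
  fromInteger n = D (fromInteger n) 0

-- The caller receives a lifting function and the infinitesimal,
-- both at a type b about which it knows only Num b.
d :: Num a => (forall b. Num b => (a -> b) -> b -> b) -> (a, a)
d f = let D y y' = f (\x -> D x 0) (D 0 1) in (y, y')

-- x^3 at x = 2.
cubeAt2 :: (Double, Double)
cubeAt2 = d (\up h -> (up 2 + h) ^ 3)

-- d/dx (x * d/dy (x + y)) at x = y = 1; the inner level has its own type.
nested :: (Double, Double)
nested = d (\up h ->
  let x = up 1 + h
  in x * snd (d (\up' k -> up' x + (up' (up 1) + k))))
```

Here `cubeAt2` is `(8.0,12.0)` and `nested` is `(1.0,1.0)`, the correct answer.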

## Automatic lifting

Although the type system now prevents us from putting calls to `lift` in the wrong place, it is still annoying to have to invoke `lift` manually, especially for nested differentiation, a useful case as discussed above. Depending on ‘how constant’ a quantity is, we need to feed it through a composition of exactly the right number of `lift`s. This manual coding is frustrating because the unique right number of `lift`s to apply is obvious from the input and output types desired: to convert a type `a` to the type `D s a`, apply `lift` once; to convert `a` to `D s (D s' a)`, apply `lift` twice; and so on. We want the compiler to manage these *subtyping coercions* automatically.

An analogous situation arises with state threads, which can be organized into a hierarchy of memory regions. As part of a monadic computation that uses mutable cells in a *parent* region, we can create a *child* region and perform a subcomputation that allocates and accesses mutable cells in *both* regions. After the subcomputation completes, the child region is destroyed en bloc, but we can still use the parent region and observe any effect on it brought about by the subcomputation. To allow the subcomputation to use the parent region, we want every region to be a *subtype* of its descendants. Matthew Fluet and Greg Morrisett’s implementation of nested regions in Haskell uses explicit subtyping coercions just like our `lift`: depending on ‘how senior’ a region is, we need to compose exactly the right number of region coercions.

In a pending submission to the Haskell symposium, Oleg and I show how to automate region subtyping coercions using type classes. One might hope to apply that approach to lifting in automatic differentiation. Indeed we can, but I only know how to automate counting `lift`s, not how to automate placing them. That is, instead of feeding each use of an input quantity to the `lift` function exactly the right number of times, we can feed it to a new function once. The new function, called `lifts`, belongs to a new type class `Lifts`, which takes two type parameters. The constraint `Lifts a b` holds if and only if the type `b` is the result of applying zero or more type constructors `D s` to the type `a`.

```haskell
class Lifts a b where
  lifts :: a -> b
```

More concretely, the following instances incompletely approximate the intended meaning of `Lifts`.

```haskell
instance Lifts a a where
  lifts = id

instance Num a => Lifts a (D s a) where
  lifts = lift

instance Num a => Lifts a (D s (D s' a)) where
  lifts = lift . lift

instance Num a => Lifts a (D s (D s' (D s'' a))) where
  lifts = lift . lift . lift
```

The definition above of `reflect` in terms of `d` applies `lift` once to one occurrence of `x` but not to the other occurrences of `x` and `h`. The same function can be expressed using `Lifts` by applying `lifts` once to each occurrence of the input variables `x` and `h`.

```haskell
reflectAuto :: Fractional a => a -> a
reflectAuto x = let (y,y') = d (\h -> curve (lifts x + lifts h))
                in y + lifts x * (recip y' - y') / 2
```

Expressing `reflect` in this new way frees us from counting how many times to `lift` the inputs `x` and `h` each time they are used.
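Putting the pieces together, here is a self-contained sketch that type-checks `reflectAuto` using the approximate instances (up to two levels here), which need only rank-2 types and multiparameter type classes with flexible instances. The helpers `slopeAt` and `curveSecond` are hypothetical: the first differentiates `reflectAuto` itself, and the second exercises the depth-two instance, where `lifts x` stands for `lift . lift` while `lifts h` stands for a single `lift`.

```haskell
{-# LANGUAGE RankNTypes, MultiParamTypeClasses, FlexibleInstances #-}

-- Branded dual numbers: s is a phantom type naming the infinitesimal.
data D s a = D a a

lift :: Num a => a -> D s a
lift x = D x 0

infinitesimal :: Num a => D s a
infinitesimal = D 0 1

instance Num a => Num (D s a) where
  D x x' + D y y' = D (x + y) (x' + y')
  D x x' * D y y' = D (x * y) (x' * y + x * y')
  negate (D x x') = D (negate x) (negate x')
  abs (D x x') = D (abs x) (signum x * x')
  signum (D x _) = lift (signum x)
  fromInteger n = lift (fromInteger n)

instance Fractional a => Fractional (D s a) where
  recip (D x x') = D (recip x) (-x'/x/x)
  fromRational r = lift (fromRational r)

d :: Num a => (forall s. D s a -> D s a) -> (a, a)
d f = let D y y' = f infinitesimal in (y, y')

-- Approximate Lifts instances: zero, one, or two levels of D.
class Lifts a b where
  lifts :: a -> b

instance Lifts a a where
  lifts = id

instance Num a => Lifts a (D s a) where
  lifts = lift

instance Num a => Lifts a (D s (D s' a)) where
  lifts = lift . lift

curve :: Fractional a => a -> a
curve x = x ^ 2 / 4

reflectAuto :: Fractional a => a -> a
reflectAuto x = let (y, y') = d (\h -> curve (lifts x + lifts h))
                in y + lifts x * (recip y' - y') / 2

-- Hypothetical: nested differentiation of reflectAuto at x.
slopeAt :: Double -> (Double, Double)
slopeAt x = d (\k -> reflectAuto (lifts x + k))

-- Hypothetical: second derivative of curve at x.
curveSecond :: Double -> Double
curveSecond x = snd (d (\h -> snd (d (\k -> curve (lifts x + lifts h + k)))))
```

For the parabola, `curveSecond 3` gives 0.5, and `slopeAt 3` gives a value of 1 and a derivative of 0, up to rounding.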

How to implement `Lifts`? On one hand, as an implementation of `Lifts`, the approximate instances above are incomplete and unsatisfactory in theory, in that they restrict how many `lift`s each `lifts` can stand for. They are perfectly useful in practice, however, and rely on no extension to Haskell other than rank-2 types and multiparameter type classes with flexible instances. On the other hand, a complete implementation is possible using the `TypeCast` class for type improvement (originally used by Kiselyov, Lämmel, and Schupke to implement heterogeneous collections), but it requires more Haskell extensions: functional dependencies, overlapping instances, and undecidable instances. Without further ado, below is the complete implementation.

```haskell
instance Lifts a a where
  lifts a = a

instance (TypeCast (D s b') b, Num b', Lifts a b') => Lifts a b where
  lifts = typeCast . lift . lifts

class TypeCast a b | a -> b, b -> a where
  typeCast :: a -> b

class TypeCast' t a b | t a -> b, t b -> a where
  typeCast' :: t -> a -> b

class TypeCast'' t a b | t a -> b, t b -> a where
  typeCast'' :: t -> a -> b

instance TypeCast' () a b => TypeCast a b where
  typeCast x = typeCast' () x

instance TypeCast'' t a b => TypeCast' t a b where
  typeCast' = typeCast''

instance TypeCast'' () a a where
  typeCast'' _ x = x
```

Although `Lifts` makes it easier to use automatic differentiation, this implementation is heavy lifting. I wonder if it is easier to express this combination of subtyping and rank-2 polymorphism in a language like Scala?