The former function, `log1pexp`, utilizes a custom `Op`, and those, just like specialized NumPy functions, will often perform seemingly standard operations with a lot more care. In other words, the two expressions you’re comparing can produce quite different results in certain cases.
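To make "quite different results" concrete, here is a minimal pure-Python sketch. The names `naive_log1pexp` and `stable_log1pexp` are made up for illustration, and the piecewise formula is just one common way to stabilize the computation; it is not necessarily what the actual `Op` does internally.

```python
import math


def naive_log1pexp(x):
    # Direct translation of log(1 + exp(x))
    return math.log(1.0 + math.exp(x))


def stable_log1pexp(x):
    # Hypothetical stabilized version (an assumption, not the library's code):
    # for x > 0, log(1 + exp(x)) == x + log1p(exp(-x)), which avoids overflow;
    # for x <= 0, log1p keeps precision when exp(x) is tiny.
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))


# For very negative inputs the naive form rounds 1 + exp(x) to exactly 1,
# so it returns 0.0, while the stable form returns a small positive number.
small_naive = naive_log1pexp(-40.0)
small_stable = stable_log1pexp(-40.0)

# For large inputs the naive form overflows inside exp(x).
try:
    naive_log1pexp(1000.0)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

large_stable = stable_log1pexp(1000.0)
```

The two functions are mathematically identical, but only the second one stays accurate at both extremes, which is the kind of care a dedicated implementation tends to take.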
Yes, one of the main purposes of Theano’s rewrites/“optimizations” is to make such replacements automatically.
What I was saying is that I don’t know if the exact rewrites you’re considering are all implemented right now. If some are missing, then we can add them, but we need to be smart about how we add them, especially given the rewrites that are already present.
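As a rough illustration of why new rewrites have to be designed with the existing ones in mind, here is a toy rewriter. It is not Theano's machinery: the tuple-based expression trees and the functions `rewrite_log1p`, `rewrite_softplus`, and `apply_rewrites` are all hypothetical stand-ins. The point is only that the second rewrite can fire solely because the first one has already produced a `log1p` node.

```python
# Toy expression trees (an assumption for illustration only):
# a leaf is a string or a number, an operation is ("op_name", arg, ...).


def rewrite_log1p(expr):
    """Rewrite log(1 + a) into log1p(a)."""
    if (isinstance(expr, tuple) and expr[0] == "log"
            and isinstance(expr[1], tuple) and expr[1][0] == "add"
            and expr[1][1] == 1):
        return ("log1p", expr[1][2])
    return expr


def rewrite_softplus(expr):
    """Rewrite log1p(exp(a)) into softplus(a)."""
    if (isinstance(expr, tuple) and expr[0] == "log1p"
            and isinstance(expr[1], tuple) and expr[1][0] == "exp"):
        return ("softplus", expr[1][1])
    return expr


def apply_rewrites(expr, rewrites):
    """Rewrite the arguments first, then the node itself until nothing changes."""
    if isinstance(expr, tuple):
        expr = (expr[0],) + tuple(apply_rewrites(a, rewrites) for a in expr[1:])
    changed = True
    while changed:
        changed = False
        for rewrite in rewrites:
            new_expr = rewrite(expr)
            if new_expr != expr:
                expr, changed = new_expr, True
    return expr


y = ("log", ("add", 1, ("exp", "x")))

# With both rewrites available, the expression collapses to softplus(x).
z = apply_rewrites(y, [rewrite_log1p, rewrite_softplus])

# With only the softplus rewrite, nothing matches and the graph is unchanged.
z_partial = apply_rewrites(y, [rewrite_softplus])
```

In this sketch, a new rewrite that matched `log(1 + exp(a))` directly would be redundant with the existing chain, and one that matched a slightly different intermediate form might never fire at all.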
Here’s a quick way to determine experimentally which optimizations are present (in `fast_run` mode, at least):

    import theano.tensor as tt
    from theano import config
    from theano.gof.graph import inputs as tt_inputs
    from theano.gof.fg import FunctionGraph
    from theano.gof.optdb import Query
    from theano.compile import optdb
    from theano.printing import debugprint as tt_dprint

    # We don't need to waste time compiling graphs to C
    config.cxx = ""


    def optimize_graphs(*graphs, include=["fast_run"], **kwargs):
        """Apply the rewrites selected by `include` and return the resulting graphs."""
        inputs = tt_inputs(graphs)
        graphs = list(graphs)
        # Build the function graph from the given variables without copying them
        fgraph = FunctionGraph(inputs, graphs, clone=False)
        canonicalize_opt = optdb.query(Query(include=include, **kwargs))
        _ = canonicalize_opt.optimize(fgraph)
        return graphs


    x = tt.vector()
    y = tt.log(1 + tt.exp(x))


    >>> # This will tell us which rewrites were used
    >>> with config.change_flags(optimizer_verbose=True):
    >>>     # Perform the rewrites
    >>>     (z,) = optimize_graphs(y)
    local_log1p Elemwise{log,no_inplace}.0 Elemwise{second,no_inplace}.0
    local_fill_to_alloc Elemwise{second,no_inplace}.0 Elemwise{log1p,no_inplace}.0
    Elemwise{log1p,no_inplace}(Elemwise{exp,no_inplace}(x)) -> softplus(x) Elemwise{log1p,no_inplace}.0 softplus.0
    ...
    inplace_elemwise_optimizer softplus.0 Elemwise{ScalarSoftplus}[(0, 0)].0
    >>> tt_dprint(z)
    Elemwise{ScalarSoftplus}[(0, 0)] [id A] ''
     |<TensorType(float64, vector)> [id B]

The output tells us that `tt.log(1 + tt.exp(x))` is rewritten to `tt.nnet.softplus(x)`, as desired.