mirror of
https://github.com/tildearrow/furnace.git
synced 2024-12-04 18:27:25 +00:00
581 lines
18 KiB
OCaml
581 lines
18 KiB
OCaml
|
(*
 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
 * Copyright (c) 2003, 2007-14 Matteo Frigo
 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
 *
 *)
|
||
|
|
||
|
|
||
|
open Util
|
||
|
open Expr
|
||
|
|
||
|
(* [node_insert key value table] inserts into an association table, using
   [Expr.hash] as the hash function.  Kept eta-expanded so the function
   stays polymorphic in the type of the stored values. *)
let node_insert key = Assoctable.insert Expr.hash key
|
||
|
(* [node_lookup key table] looks [key] up in an association table hashed
   with [Expr.hash]; keys are compared by physical equality [(==)].
   Callers match the result against [Some _] / [None]. *)
let node_lookup key = Assoctable.lookup Expr.hash (==) key
|
||
|
|
||
|
(*************************************************************
 * Algebraic simplifier/elimination of common subexpressions
 *************************************************************)
|
||
|
module AlgSimp : sig
  val algsimp : expr list -> expr list
end = struct

  open Monads.StateMonad
  open Monads.MemoMonad
  open Assoctable

  (* The monad state is a pair (simp, cse).  [simp] is the memo table used
     by [algsimpM]; [cse] is the table of nodes created via [makeNode]. *)

  let fetchSimp =
    fetchState >>= fun (simp, _) -> returnM simp

  let storeSimp simp =
    fetchState >>= fun (_, cse) -> storeState (simp, cse)

  let lookupSimpM key =
    fetchSimp >>= fun table -> returnM (node_lookup key table)

  let insertSimpM key value =
    fetchSimp >>= fun table -> storeSimp (node_insert key value table)

  (* every element of [a] occurs (physically) in [b] *)
  let subset a b = List.for_all (fun x -> List.memq x b) a

  (* Shallow equality used by the CSE table: numbers and variables are
     compared by value, operands by physical identity.  Times and CTimes
     are treated as commutative, CTimesJ is not; Plus compares operand
     lists as sets. *)
  let structurallyEqualCSE a b =
    match (a, b) with
    | (Num x, Num y) -> Number.equal x y
    | (NaN x, NaN y) -> x == y
    | (Load x, Load y) -> Variable.same x y
    | (Times (l1, r1), Times (l2, r2))
    | (CTimes (l1, r1), CTimes (l2, r2)) ->
        (l1 == l2 && r1 == r2) || (l1 == r2 && r1 == l2)
    | (CTimesJ (l1, r1), CTimesJ (l2, r2)) -> l1 == l2 && r1 == r2
    | (Plus xs, Plus ys) -> subset xs ys && subset ys xs
    | (Uminus x, Uminus y) -> x == y
    | _ -> false

  (* hash function of the CSE table, selected by Magic.randomized_cse *)
  let hashCSE x =
    if !Magic.randomized_cse then Oracle.hash x else Expr.hash x

  (* equality of the CSE table; with Magic.randomized_cse set, nodes that
     the Oracle reports as likely equal are identified as well *)
  let equalCSE a b =
    structurallyEqualCSE a b
    || (!Magic.randomized_cse && Oracle.likely_equal a b)

  let fetchCSE =
    fetchState >>= fun (_, cse) -> returnM cse

  let storeCSE cse =
    fetchState >>= fun (simp, _) -> storeState (simp, cse)

  let lookupCSEM key =
    fetchCSE >>= fun table ->
    returnM (Assoctable.lookup hashCSE equalCSE key table)

  let insertCSEM key value =
    fetchCSE >>= fun table ->
    storeCSE (Assoctable.insert hashCSE key value table)

  (* memoize both x and Uminus x (unless x is already negated) *)
  let identityM x =
    let memo node = memoizing lookupCSEM insertCSEM returnM node in
    match x with
    | Uminus _ -> memo x
    | _ -> memo x >>= fun x' -> memo (Uminus x') >> returnM x'

  let makeNode = identityM

  (* simplifiers for various kinds of nodes *)

  (* numeric constant: zero is normalized to Number.zero and a negative
     constant becomes the negation of its positive counterpart *)
  let rec snumM n =
    if Number.is_zero n then
      makeNode (Num Number.zero)
    else if Number.negative n then
      makeNode (Num (Number.negate n)) >>= suminusM
    else
      makeNode (Num n)

  (* negation: -(-x) = x and -0 = 0 *)
  and suminusM = function
    | Uminus x -> makeNode x
    | Num a when Number.is_zero a -> snumM Number.zero
    | a -> makeNode (Uminus a)

  (* product; the order of the cases is significant *)
  and stimesM = function
    | (Uminus a, b) -> stimesM (a, b) >>= suminusM
    | (a, Uminus b) -> stimesM (a, b) >>= suminusM
    | (NaN I, CTimes (a, b)) ->
        stimesM (NaN I, b) >>= fun ib -> sctimesM (a, ib)
    | (NaN I, CTimesJ (a, b)) ->
        stimesM (NaN I, b) >>= fun ib -> sctimesjM (a, ib)
    | (Num a, Num b) -> snumM (Number.mul a b)
    | (Num a, Times (Num b, c)) ->
        snumM (Number.mul a b) >>= fun x -> stimesM (x, c)
    | (Num a, _) when Number.is_zero a -> snumM Number.zero
    | (Num a, b) when Number.is_one a -> makeNode b
    | (Num a, b) when Number.is_mone a -> suminusM b
    | (a, b) when is_known_constant b && not (is_known_constant a) ->
        (* move the known constant into the first position *)
        stimesM (b, a)
    | (a, b) -> makeNode (Times (a, b))

  and sctimesM = function
    | (Uminus a, b) -> sctimesM (a, b) >>= suminusM
    | (a, Uminus b) -> sctimesM (a, b) >>= suminusM
    | (a, b) -> makeNode (CTimes (a, b))

  and sctimesjM = function
    | (Uminus a, b) -> sctimesjM (a, b) >>= suminusM
    | (a, Uminus b) -> sctimesjM (a, b) >>= suminusM
    | (a, b) -> makeNode (CTimesJ (a, b))

  (* fold the numeric constants of a sum into at most one term; a constant
     is moved towards the end of the list until it meets another constant,
     and a final zero constant is dropped *)
  and reduce_sumM x =
    match x with
    | [] -> returnM []
    | [Num a] | [Uminus (Num a)] ->
        if Number.is_zero a then returnM [] else returnM x
    | Num a :: Num b :: s ->
        snumM (Number.add a b) >>= fun c -> reduce_sumM (c :: s)
    | Num a :: Uminus (Num b) :: s ->
        snumM (Number.sub a b) >>= fun c -> reduce_sumM (c :: s)
    | Uminus (Num a) :: Num b :: s ->
        snumM (Number.sub b a) >>= fun c -> reduce_sumM (c :: s)
    | Uminus (Num a) :: Uminus (Num b) :: s ->
        snumM (Number.add a b) >>= suminusM >>= fun c ->
        reduce_sumM (c :: s)
    | ((Num _ | Uminus (Num _)) as a) :: b :: s ->
        reduce_sumM (b :: a :: s)
    | a :: s ->
        reduce_sumM s >>= fun s' -> returnM (a :: s')

  and collectible1 = function
    | NaN _ -> false
    | Uminus x -> collectible1 x
    | _ -> true

  (* only the first component of the pair decides *)
  and collectible (a, _) = collectible1 a

  (* collect common factors: ax + bx -> (a+b)x *)
  and collectM which x =
    let keep (a, b) = (a, b)
    and swap (a, b) = (b, a) in
    (* split a term into (coefficient, factor); [which] picks the
       orientation of a product, anything else gets coefficient one *)
    let rec findCoeffM which = function
      | Times (a, b) when collectible (which (a, b)) ->
          returnM (which (a, b))
      | Uminus e ->
          findCoeffM which e >>= fun (coeff, b) ->
          suminusM coeff >>= fun mcoeff ->
          returnM (mcoeff, b)
      | e -> snumM Number.one >>= fun one -> returnM (one, e)
    (* partition the terms into the coefficients of those that contain
       [xpr] as a factor, and the terms that do not *)
    and separateM xpr = function
      | [] -> returnM ([], [])
      | hd :: tl ->
          separateM xpr tl >>= fun (w, wo) ->
          (* try first factor *)
          findCoeffM keep hd >>= fun (c, x) ->
          if xpr == x && collectible (c, x) then returnM (c :: w, wo)
          else
            (* try second factor *)
            findCoeffM swap hd >>= fun (c, x) ->
            if xpr == x && collectible (c, x) then returnM (c :: w, wo)
            else returnM (w, hd :: wo)
    in
    match x with
    | [] | [_] -> returnM x
    | first :: _ ->
        findCoeffM which first >>= fun (_, xpr) ->
        separateM xpr x >>= fun (w, wo) ->
        collectM which wo >>= fun wo' ->
        splusM w >>= fun w' ->
        stimesM (w', xpr) >>= fun t' ->
        returnM (t' :: wo')

  (* the pipeline of sum rewrites applied before a sum is canonicalized *)
  and mangleSumM x =
    returnM x
    >>= reduce_sumM
    >>= collectM (fun (a, b) -> (a, b))
    >>= collectM (fun (a, b) -> (b, a))
    >>= reduce_sumM
    >>= deepCollectM !Magic.deep_collect_depth
    >>= reduce_sumM

  (* push all Uminuses to the end: the other terms keep their order and
     are followed by the negated terms in reverse order *)
  and reorder_uminus terms =
    let (negated, others) = List.partition negative terms in
    others @ List.rev negated

  and canonicalizeM = function
    | [] -> snumM Number.zero
    | [a] -> makeNode a                    (* one term *)
    | terms -> generateFusedMultAddM (reorder_uminus terms)

  (* With Magic.enable_fma set and at least two terms of the form
     (+/-) Num * x, divide those terms by their largest coefficient, sum
     them, and multiply the sum by that coefficient again.  Otherwise
     build a plain Plus node. *)
  and generateFusedMultAddM l =
    let is_multiplication = function
      | Times (Num _, _) | Uminus (Times (Num _, _)) -> true
      | _ -> false
    in
    (* returns (scaled terms, other terms, largest coefficient seen) *)
    let rec separate = function
      | [] -> ([], [], Number.zero)
      | ((Times (Num a, _) | Uminus (Times (Num a, _))) as this) :: rest ->
          let (w, wo, largest) = separate rest in
          let largest = if Number.greater a largest then a else largest in
          (this :: w, wo, largest)
      | this :: rest ->
          let (w, wo, largest) = separate rest in
          (w, this :: wo, largest)
    in
    if !Magic.enable_fma && count is_multiplication l >= 2 then
      let (w, wo, largest) = separate l in
      snumM (Number.div Number.one largest) >>= fun invmax' ->
      snumM largest >>= fun max' ->
      mapM (fun x -> stimesM (invmax', x)) w >>= splusM >>= fun pw' ->
      stimesM (max', pw') >>= fun mw' ->
      splusM (wo @ [mw'])
    else
      makeNode (Plus l)

  and negative = function
    | Uminus _ -> true
    | _ -> false

  (*
   * simplify patterns of the form
   *
   *  ((c_1 * a + ...) + ...) +  (c_2 * a + ...)
   *
   * The pattern includes arbitrary coefficients and minus signs.
   * A common case of this pattern is the butterfly
   *   (a + b) + (a - b)
   *   (a + b) - (a - b)
   *)
  (* this whole procedure needs much more thought *)
  and deepCollectM maxdepth l =
    (* terms reachable through at most [depth] levels of Plus; minus signs
       and numeric coefficients are looked through *)
    let rec findTerms depth = function
      | Uminus x -> findTerms depth x
      | Times (Num _, b) -> findTerms (depth - 1) b
      | (Plus terms) as node when depth > 0 ->
          node :: List.flatten (List.map (findTerms (depth - 1)) terms)
      | x -> [x]
    (* elements that occur (physically) again later in the list *)
    and duplicates = function
      | [] -> []
      | a :: b ->
          if List.memq a b then a :: duplicates b else duplicates b
    in
    (* split x into a pair (part free of the duplicates in d, part made of
       those duplicates).
       NOTE(review): in the Times case, when both factors contribute a
       non-zero duplicate part, the three products summed below seem to
       overlap (a = a' + xa, b = b' + xb) -- confirm whether that case
       can actually occur. *)
    let rec splitDuplicates depth d x =
      if List.memq x d then
        snumM Number.zero >>= fun zero ->
        returnM (zero, x)
      else
        match x with
        | Times (a, b) ->
            splitDuplicates (depth - 1) d a >>= fun (a', xa) ->
            splitDuplicates (depth - 1) d b >>= fun (b', xb) ->
            stimesM (a', b') >>= fun ab ->
            stimesM (a, xb) >>= fun xb' ->
            stimesM (xa, b) >>= fun xa' ->
            stimesM (xa, xb) >>= fun xab ->
            splusM [xa'; xb'; xab] >>= fun dup ->
            returnM (ab, dup)
        | Uminus a ->
            splitDuplicates depth d a >>= fun (p, q) ->
            suminusM p >>= fun up ->
            suminusM q >>= fun uq ->
            returnM (up, uq)
        | Plus terms when depth > 0 ->
            mapM (splitDuplicates (depth - 1) d) terms >>= fun ld ->
            let (rests, dups) = List.split ld in
            splusM rests >>= fun p ->
            splusM dups >>= fun q ->
            returnM (p, q)
        | leaf ->
            snumM Number.zero >>= fun zero' ->
            returnM (leaf, zero')
    in
    match duplicates (List.flatten (List.map (findTerms maxdepth) l)) with
    | [] -> returnM l
    | d ->
        mapM (splitDuplicates maxdepth d) l >>= fun ld ->
        let (rests, dups) = List.split ld in
        splusM rests >>= fun rest_sum ->
        let rec flattenPlusM = function
          | Plus terms -> returnM terms
          | Uminus x -> flattenPlusM x >>= mapM suminusM
          | x -> returnM [x]
        in
        mapM flattenPlusM dups >>= fun dup_terms ->
        splusM (List.flatten dup_terms) >>= fun dup_sum ->
        mangleSumM [rest_sum; dup_sum]

  (* sum: mangle the terms, then decide whether to emit the sum as is or
     as the negation of the sum of the negated terms *)
  and splusM l =
    (* Some true: flip the sign; Some false: keep it; None: no opinion.
       The order of the patterns is significant. *)
    let fma_heuristics x =
      if not !Magic.enable_fma then None
      else
        match x with
        | [Uminus (Times _); Times _] -> Some false
        | [Times _; Uminus (Times _)] -> Some false
        | [Uminus (_); Times _] -> Some true
        | [Times _; Uminus (Plus _)] -> Some true
        | [_; Uminus (Times _)] -> Some false
        | [Uminus (Times _); _] -> Some false
        | _ -> None
    in
    mangleSumM l >>= fun l' ->
    (* negate every term, re-simplify the sum, and negate the result *)
    let flip_sign () = mapM suminusM l' >>= splusM >>= suminusM in
    (* no terms are negative.  Don't do anything *)
    if not (List.exists negative l') then
      canonicalizeM l'
    (* all terms are negative.  Negate them all and collect the minus sign *)
    else if List.for_all negative l' then
      flip_sign ()
    else
      match fma_heuristics l' with
      | Some true -> flip_sign ()
      | Some false -> canonicalizeM l'
      | None ->
          (* Ask the Oracle for the canonical form *)
          if (not !Magic.randomized_cse)
             && Oracle.should_flip_sign (Plus l') then
            flip_sign ()
          else
            canonicalizeM l'

  (* monadic style algebraic simplifier for the dag *)
  let rec algsimpM x =
    (* simplify both operands, left one first, then combine them *)
    let simp_pair combine (a, b) =
      algsimpM a >>= fun a' ->
      algsimpM b >>= fun b' ->
      combine (a', b')
    in
    memoizing lookupSimpM insertSimpM
      (function
        | Num a -> snumM a
        | (NaN _ | Load _) as leaf -> makeNode leaf
        | Plus terms -> mapM algsimpM terms >>= splusM
        | Times (a, b) -> simp_pair stimesM (a, b)
        | CTimes (a, b) -> simp_pair sctimesM (a, b)
        | CTimesJ (a, b) -> simp_pair sctimesjM (a, b)
        | Uminus a -> algsimpM a >>= suminusM
        | Store (v, a) ->
            algsimpM a >>= fun a' ->
            makeNode (Store (v, a')))
      x

  (* both tables start out empty *)
  let initialTable = (empty, empty)
  let simp_roots = mapM algsimpM
  let algsimp = runM initialTable simp_roots
end
|
||
|
|
||
|
(*************************************************************
 * Network transposition algorithm
 *************************************************************)
|
||
|
module Transpose = struct
  open Monads.StateMonad
  open Monads.MemoMonad
  open Littlesimp

  (* the monad state is the table mapping a node to its dual *)
  let fetchDuals = fetchState
  let storeDuals = storeState

  let lookupDualsM key =
    fetchDuals >>= fun table -> returnM (node_lookup key table)

  let insertDualsM key value =
    fetchDuals >>= fun table -> storeDuals (node_insert key value table)

  (* Walk the dag from the given work list.  Returns the list of all nodes
     reached (each once) and a table mapping every node to the list of its
     parents.  [vtable] records the nodes already visited. *)
  let rec visit visited vtable parent_table = function
    | [] -> (visited, parent_table)
    | node :: rest ->
        (match node_lookup node vtable with
         | Some _ -> visit visited vtable parent_table rest
         | None ->
             let children =
               match node with
               | Store (_, n) | Uminus n -> [n]
               | Plus l -> l
               | Times (a, b) | CTimes (a, b) | CTimesJ (a, b) -> [a; b]
               | _ -> []
             in
             (* record [node] as one more parent of [child] *)
             let add_parent table child =
               match node_lookup child table with
               | None -> node_insert child [node] table
               | Some parents -> node_insert child (node :: parents) table
             in
             visit
               (node :: visited)
               (node_insert node () vtable)
               (List.fold_left add_parent parent_table children)
               (children @ rest))

  (* Build the memoized function that maps a node to its dual, given the
     parent table computed by [visit]. *)
  let make_transposer parent_table =
    (* contribution of [candidate_parent] to the dual of [node]: a
       singleton list when the parent uses [node] in a transposable
       position, the empty list otherwise *)
    let rec termM node candidate_parent =
      let through =
        match candidate_parent with
        | Store (_, n) when n == node -> Some (fun d -> d)
        | Plus l when List.memq node l -> Some (fun d -> d)
        | Times (a, b) when b == node -> Some (fun d -> makeTimes (a, d))
        | CTimes (a, b) when b == node -> Some (fun d -> CTimes (a, d))
        | CTimesJ (a, b) when b == node -> Some (fun d -> CTimesJ (a, d))
        | Uminus n when n == node -> Some (fun d -> makeUminus d)
        | _ -> None
      in
      match through with
      | None -> returnM []
      | Some build ->
          dualM candidate_parent >>= fun x' -> returnM [build x']

    (* sum of the contributions of all parents of [this_node].
       Fails with "bug in dualExpressionM" if the node has no entry in the
       parent table. *)
    and dualExpressionM this_node =
      let parents =
        match node_lookup this_node parent_table with
        | Some p -> p
        | None -> failwith "bug in dualExpressionM"
      in
      mapM (termM this_node) parents >>= fun l ->
      returnM (makePlus (List.flatten l))

    (* a load of a non-constant variable becomes a store of its dual
       expression, a store becomes a load, a constant load is unchanged *)
    and dualM this_node =
      memoizing lookupDualsM insertDualsM
        (function
          | (Load v) as x ->
              if Variable.is_constant v then
                returnM (Load v)
              else
                dualExpressionM x >>= fun d ->
                returnM (Store (v, d))
          | Store (v, _) -> returnM (Load v)
          | x -> dualExpressionM x)
        this_node

    in dualM

  let is_store = function
    | Store _ -> true
    | _ -> false

  (* Transpose the network given by the list of root nodes [dag]; the
     result is the list of Store nodes among the duals of all nodes. *)
  let transpose dag =
    (* [let () =] rather than [let _ =]: asserts that Util.info is fully
       applied and returns unit, as the rest of this file does *)
    let () = Util.info "begin transpose" in
    let (all_nodes, parent_table) =
      visit [] Assoctable.empty Assoctable.empty dag in
    let transposerM = make_transposer parent_table in
    let mapTransposerM = mapM transposerM in
    let duals = runM Assoctable.empty mapTransposerM all_nodes in
    let roots = List.filter is_store duals in
    let () = Util.info "end transpose" in
    roots
end
|
||
|
|
||
|
|
||
|
(*************************************************************
 * Various dag statistics
 *************************************************************)
|
||
|
module Stats : sig
  type complexity
  val complexity : Expr.expr list -> complexity
  val same_complexity : complexity -> complexity -> bool
  val leq_complexity : complexity -> complexity -> bool
  val to_string : complexity -> string
end = struct
  (* counts of (loads, stores, additions, multiplications, negations,
     numeric constants) *)
  type complexity = int * int * int * int * int * int

  (* Collect every node reachable from the work list, each once.
     NOTE(review): operands of CTimes / CTimesJ are not followed here,
     unlike in Transpose.visit -- confirm that this is intended. *)
  let rec visit visited vtable = function
    | [] -> visited
    | node :: rest ->
        (match node_lookup node vtable with
         | Some _ -> visit visited vtable rest
         | None ->
             let children =
               match node with
               | Store (_, n) | Uminus n -> [n]
               | Plus l -> l
               | Times (a, b) -> [a; b]
               | _ -> []
             in
             visit (node :: visited)
               (node_insert node () vtable)
               (children @ rest))

  (* Count the visited nodes by kind.  A Plus with k operands counts as
     k - 1 additions; CTimes, CTimesJ and NaN nodes are not counted. *)
  let complexity dag =
    let tally (load, store, plus, times, uminus, num) = function
      | Load _ -> (load + 1, store, plus, times, uminus, num)
      | Store _ -> (load, store + 1, plus, times, uminus, num)
      | Plus x -> (load, store, plus + (List.length x - 1), times, uminus, num)
      | Times _ -> (load, store, plus, times + 1, uminus, num)
      | Uminus _ -> (load, store, plus, times, uminus + 1, num)
      | Num _ -> (load, store, plus, times, uminus, num + 1)
      | CTimes _ | CTimesJ _ | NaN _ ->
          (load, store, plus, times, uminus, num)
    in
    List.fold_left tally (0, 0, 0, 0, 0, 0) (visit [] Assoctable.empty dag)

  (* scalar cost: additions weigh 10, multiplications 20, the rest 1 *)
  let weight (l, s, p, t, u, n) =
    l + s + 10 * p + 20 * t + u + n

  let same_complexity a b = (weight a) = (weight b)
  let leq_complexity a b = (weight a) <= (weight b)

  let to_string (l, s, p, t, u, n) =
    Printf.sprintf "ld=%d st=%d add=%d mul=%d uminus=%d num=%d\n"
      l s p t u n

end
|
||
|
|
||
|
(* simplify the dag *)
|
||
|
(* Entry point.  Runs AlgSimp.algsimp once; when Magic.network_transposition
   is set, it then repeats rounds of transpose / simplify / transpose /
   simplify.  A round's result is returned as soon as its complexity is
   less than or equal to the complexity before the round; otherwise
   another round is run.
   NOTE(review): the loop therefore continues only while the complexity
   grew -- confirm that this stopping rule is the intended one. *)
let algsimp v =
  let report c = Util.info ("complexity = " ^ (Stats.to_string c)) in
  let rec simplification_loop v =
    let () = Util.info "simplification step" in
    let before = Stats.complexity v in
    let () = report before in
    let v = (AlgSimp.algsimp @@ Transpose.transpose @@
             AlgSimp.algsimp @@ Transpose.transpose) v in
    let after = Stats.complexity v in
    let () = report after in
    if Stats.leq_complexity after before then begin
      Util.info "end algsimp";
      v
    end else
      simplification_loop v
  in
  let () = Util.info "begin algsimp" in
  let v = AlgSimp.algsimp v in
  if !Magic.network_transposition then simplification_loop v else v
|
||
|
|