@@ -569,7 +569,7 @@ module type CircuitInterface = sig
569569 val circuit_is_free : circuit -> bool
570570
571571 (* Direct circuuit constructions *)
572- val circuit_ite : c :circuit -> t :circuit -> f :circuit -> circuit
572+ val circuit_ite : ? strict : bool -> c :circuit -> t :circuit -> f :circuit -> circuit
573573 val circuit_eq : circuit -> circuit -> circuit
574574 val circuit_eqs : circuit -> circuit -> circuit list
575575
@@ -999,18 +999,22 @@ module MakeCircuitInterfaceFromCBackend(Backend: CBackend) : CircuitInterface =
999999
10001000 let circuit_is_free (f : circuit ) : bool = List. is_empty @@ snd f
10011001
1002- let circuit_ite ~(c : circuit ) ~(t : circuit ) ~(f : circuit ) : circuit =
1003- assert ((circuit_is_free t) && (circuit_is_free f) && (circuit_is_free c));
1002+ let circuit_ite ?(strict = false ) ~(c : circuit ) ~(t : circuit ) ~(f : circuit ) : circuit =
1003+ let inps = match c, t, f with
1004+ | (_ , [] ), (_ , [] ), (_ , [] ) when strict -> []
1005+ | (_ , cinps ), (_ , tinps ), (_ , finps ) when (not strict) && cinps = tinps && cinps = finps -> cinps
1006+ | _ -> assert false
1007+ in
10041008 let c = match (fst c).type_ with
10051009 | CBool -> Backend. node_of_reg (fst c).reg
10061010 | _ -> assert false
10071011 in
10081012 let res_r = Backend. reg_ite c (fst t).reg (fst f).reg in
10091013 match ((fst t).type_, (fst f).type_) with
1010- | CBitstring nt , CBitstring nf when nt = nf -> {reg = res_r; type_ = (fst t).type_}, []
1011- | CArray {width =wt ; count =nt } , CArray {width =wf ; count =nf } when wt = wf && nt = nf -> {reg = res_r; type_ = (fst t).type_}, []
1012- | CTuple szs_t , CTuple szs_f when List. all2 (= ) szs_t szs_f -> {reg = res_r; type_ = (fst t).type_}, []
1013- | CBool , CBool -> {reg = res_r; type_ = (fst t).type_}, []
1014+ | CBitstring nt , CBitstring nf when nt = nf -> {reg = res_r; type_ = (fst t).type_}, inps
1015+ | CArray {width =wt ; count =nt } , CArray {width =wf ; count =nf } when wt = wf && nt = nf -> {reg = res_r; type_ = (fst t).type_}, inps
1016+ | CTuple szs_t , CTuple szs_f when List. all2 (= ) szs_t szs_f -> {reg = res_r; type_ = (fst t).type_}, inps
1017+ | CBool , CBool -> {reg = res_r; type_ = (fst t).type_}, inps
10141018 | _ -> raise CircConstructorInvalidArguments
10151019
10161020 (* TODO: type check? *)
0 commit comments