Decrease WriteJacobian complexity

pull/1159/head
EvilSpirit 2017-05-12 09:39:07 +07:00 committed by phkahler
parent 4d58f95b43
commit 2f31673708
3 changed files with 35 additions and 22 deletions

View File

@ -400,15 +400,21 @@ Expr *Expr::PartialWrt(hParam p) const {
ssassert(false, "Unexpected operation");
}
uint64_t Expr::ParamsUsed() const {
uint64_t r = 0;
if(op == Op::PARAM) r |= ((uint64_t)1 << (parh.v % 61));
if(op == Op::PARAM_PTR) r |= ((uint64_t)1 << (parp->h.v % 61));
void Expr::ParamsUsedList(List<hParam> *list) const {
if(op == Op::PARAM || op == Op::PARAM_PTR) {
hParam param = (op == Op::PARAM) ? parh : parp->h;
for(hParam &p : *list) {
if(p.v == param.v) return;
}
list->Add(&param);
return;
}
int c = Children();
if(c >= 1) r |= a->ParamsUsed();
if(c >= 2) r |= b->ParamsUsed();
return r;
if(c >= 1) {
a->ParamsUsedList(list);
if(c >= 2) b->ParamsUsedList(list);
}
}
bool Expr::DependsOn(hParam p) const {

View File

@ -70,7 +70,7 @@ public:
Expr *PartialWrt(hParam p) const;
double Eval() const;
uint64_t ParamsUsed() const;
void ParamsUsedList(List<hParam> *list) const;
bool DependsOn(hParam p) const;
static bool Tol(double a, double b);
Expr *FoldConstants();

View File

@ -33,8 +33,14 @@ bool System::WriteJacobian(int tag) {
}
mat.n = j;
int i = 0;
// Fill the param id to index map
std::map<uint32_t, int> paramToIndex;
for(int j = 0; j < mat.n; j++) {
paramToIndex[mat.param[j].v] = j;
}
int i = 0;
Expr *zero = Expr::From(0.0);
for(auto &e : eq) {
if(i >= MAX_UNKNOWNS) return false;
@ -45,21 +51,22 @@ bool System::WriteJacobian(int tag) {
Expr *f = e.e->DeepCopyWithParamsAsPointers(&param, &(SK.param));
f = f->FoldConstants();
// Hash table (61 bits) to accelerate generation of zero partials.
uint64_t scoreboard = f->ParamsUsed();
for(j = 0; j < mat.n; j++) {
Expr *pd;
if(scoreboard & ((uint64_t)1 << (mat.param[j].v % 61)) &&
f->DependsOn(mat.param[j]))
{
pd = f->PartialWrt(mat.param[j]);
pd = pd->FoldConstants();
pd = pd->DeepCopyWithParamsAsPointers(&param, &(SK.param));
} else {
pd = Expr::From(0.0);
}
mat.A.sym[i][j] = pd;
mat.A.sym[i][j] = zero;
}
List<hParam> paramsUsed = {};
f->ParamsUsedList(&paramsUsed);
for(hParam &p : paramsUsed) {
auto j = paramToIndex.find(p.v);
if(j == paramToIndex.end()) continue;
Expr *pd = f->PartialWrt(p);
pd = pd->FoldConstants();
pd = pd->DeepCopyWithParamsAsPointers(&param, &(SK.param));
mat.A.sym[i][j->second] = pd;
}
paramsUsed.Clear();
mat.B.sym[i] = f;
i++;
}