Decrease WriteJacobian complexity

This commit is contained in:
EvilSpirit 2017-05-12 09:39:07 +07:00 committed by phkahler
parent 4d58f95b43
commit 2f31673708
3 changed files with 35 additions and 22 deletions

View File

@ -400,15 +400,21 @@ Expr *Expr::PartialWrt(hParam p) const {
ssassert(false, "Unexpected operation"); ssassert(false, "Unexpected operation");
} }
uint64_t Expr::ParamsUsed() const { void Expr::ParamsUsedList(List<hParam> *list) const {
uint64_t r = 0; if(op == Op::PARAM || op == Op::PARAM_PTR) {
if(op == Op::PARAM) r |= ((uint64_t)1 << (parh.v % 61)); hParam param = (op == Op::PARAM) ? parh : parp->h;
if(op == Op::PARAM_PTR) r |= ((uint64_t)1 << (parp->h.v % 61)); for(hParam &p : *list) {
if(p.v == param.v) return;
}
list->Add(&param);
return;
}
int c = Children(); int c = Children();
if(c >= 1) r |= a->ParamsUsed(); if(c >= 1) {
if(c >= 2) r |= b->ParamsUsed(); a->ParamsUsedList(list);
return r; if(c >= 2) b->ParamsUsedList(list);
}
} }
bool Expr::DependsOn(hParam p) const { bool Expr::DependsOn(hParam p) const {

View File

@ -70,7 +70,7 @@ public:
Expr *PartialWrt(hParam p) const; Expr *PartialWrt(hParam p) const;
double Eval() const; double Eval() const;
uint64_t ParamsUsed() const; void ParamsUsedList(List<hParam> *list) const;
bool DependsOn(hParam p) const; bool DependsOn(hParam p) const;
static bool Tol(double a, double b); static bool Tol(double a, double b);
Expr *FoldConstants(); Expr *FoldConstants();

View File

@ -33,8 +33,14 @@ bool System::WriteJacobian(int tag) {
} }
mat.n = j; mat.n = j;
int i = 0; // Fill the param id to index map
std::map<uint32_t, int> paramToIndex;
for(int j = 0; j < mat.n; j++) {
paramToIndex[mat.param[j].v] = j;
}
int i = 0;
Expr *zero = Expr::From(0.0);
for(auto &e : eq) { for(auto &e : eq) {
if(i >= MAX_UNKNOWNS) return false; if(i >= MAX_UNKNOWNS) return false;
@ -45,21 +51,22 @@ bool System::WriteJacobian(int tag) {
Expr *f = e.e->DeepCopyWithParamsAsPointers(&param, &(SK.param)); Expr *f = e.e->DeepCopyWithParamsAsPointers(&param, &(SK.param));
f = f->FoldConstants(); f = f->FoldConstants();
// Hash table (61 bits) to accelerate generation of zero partials.
uint64_t scoreboard = f->ParamsUsed();
for(j = 0; j < mat.n; j++) { for(j = 0; j < mat.n; j++) {
Expr *pd; mat.A.sym[i][j] = zero;
if(scoreboard & ((uint64_t)1 << (mat.param[j].v % 61)) && }
f->DependsOn(mat.param[j]))
{ List<hParam> paramsUsed = {};
pd = f->PartialWrt(mat.param[j]); f->ParamsUsedList(&paramsUsed);
for(hParam &p : paramsUsed) {
auto j = paramToIndex.find(p.v);
if(j == paramToIndex.end()) continue;
Expr *pd = f->PartialWrt(p);
pd = pd->FoldConstants(); pd = pd->FoldConstants();
pd = pd->DeepCopyWithParamsAsPointers(&param, &(SK.param)); pd = pd->DeepCopyWithParamsAsPointers(&param, &(SK.param));
} else { mat.A.sym[i][j->second] = pd;
pd = Expr::From(0.0);
}
mat.A.sym[i][j] = pd;
} }
paramsUsed.Clear();
mat.B.sym[i] = f; mat.B.sym[i] = f;
i++; i++;
} }