Decrease WriteJacobian complexity
This commit is contained in:
parent
4d58f95b43
commit
2f31673708
20
src/expr.cpp
20
src/expr.cpp
@ -400,15 +400,21 @@ Expr *Expr::PartialWrt(hParam p) const {
|
|||||||
ssassert(false, "Unexpected operation");
|
ssassert(false, "Unexpected operation");
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t Expr::ParamsUsed() const {
|
void Expr::ParamsUsedList(List<hParam> *list) const {
|
||||||
uint64_t r = 0;
|
if(op == Op::PARAM || op == Op::PARAM_PTR) {
|
||||||
if(op == Op::PARAM) r |= ((uint64_t)1 << (parh.v % 61));
|
hParam param = (op == Op::PARAM) ? parh : parp->h;
|
||||||
if(op == Op::PARAM_PTR) r |= ((uint64_t)1 << (parp->h.v % 61));
|
for(hParam &p : *list) {
|
||||||
|
if(p.v == param.v) return;
|
||||||
|
}
|
||||||
|
list->Add(¶m);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
int c = Children();
|
int c = Children();
|
||||||
if(c >= 1) r |= a->ParamsUsed();
|
if(c >= 1) {
|
||||||
if(c >= 2) r |= b->ParamsUsed();
|
a->ParamsUsedList(list);
|
||||||
return r;
|
if(c >= 2) b->ParamsUsedList(list);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Expr::DependsOn(hParam p) const {
|
bool Expr::DependsOn(hParam p) const {
|
||||||
|
@ -70,7 +70,7 @@ public:
|
|||||||
|
|
||||||
Expr *PartialWrt(hParam p) const;
|
Expr *PartialWrt(hParam p) const;
|
||||||
double Eval() const;
|
double Eval() const;
|
||||||
uint64_t ParamsUsed() const;
|
void ParamsUsedList(List<hParam> *list) const;
|
||||||
bool DependsOn(hParam p) const;
|
bool DependsOn(hParam p) const;
|
||||||
static bool Tol(double a, double b);
|
static bool Tol(double a, double b);
|
||||||
Expr *FoldConstants();
|
Expr *FoldConstants();
|
||||||
|
@ -33,8 +33,14 @@ bool System::WriteJacobian(int tag) {
|
|||||||
}
|
}
|
||||||
mat.n = j;
|
mat.n = j;
|
||||||
|
|
||||||
int i = 0;
|
// Fill the param id to index map
|
||||||
|
std::map<uint32_t, int> paramToIndex;
|
||||||
|
for(int j = 0; j < mat.n; j++) {
|
||||||
|
paramToIndex[mat.param[j].v] = j;
|
||||||
|
}
|
||||||
|
|
||||||
|
int i = 0;
|
||||||
|
Expr *zero = Expr::From(0.0);
|
||||||
for(auto &e : eq) {
|
for(auto &e : eq) {
|
||||||
if(i >= MAX_UNKNOWNS) return false;
|
if(i >= MAX_UNKNOWNS) return false;
|
||||||
|
|
||||||
@ -45,21 +51,22 @@ bool System::WriteJacobian(int tag) {
|
|||||||
Expr *f = e.e->DeepCopyWithParamsAsPointers(¶m, &(SK.param));
|
Expr *f = e.e->DeepCopyWithParamsAsPointers(¶m, &(SK.param));
|
||||||
f = f->FoldConstants();
|
f = f->FoldConstants();
|
||||||
|
|
||||||
// Hash table (61 bits) to accelerate generation of zero partials.
|
|
||||||
uint64_t scoreboard = f->ParamsUsed();
|
|
||||||
for(j = 0; j < mat.n; j++) {
|
for(j = 0; j < mat.n; j++) {
|
||||||
Expr *pd;
|
mat.A.sym[i][j] = zero;
|
||||||
if(scoreboard & ((uint64_t)1 << (mat.param[j].v % 61)) &&
|
}
|
||||||
f->DependsOn(mat.param[j]))
|
|
||||||
{
|
List<hParam> paramsUsed = {};
|
||||||
pd = f->PartialWrt(mat.param[j]);
|
f->ParamsUsedList(¶msUsed);
|
||||||
|
|
||||||
|
for(hParam &p : paramsUsed) {
|
||||||
|
auto j = paramToIndex.find(p.v);
|
||||||
|
if(j == paramToIndex.end()) continue;
|
||||||
|
Expr *pd = f->PartialWrt(p);
|
||||||
pd = pd->FoldConstants();
|
pd = pd->FoldConstants();
|
||||||
pd = pd->DeepCopyWithParamsAsPointers(¶m, &(SK.param));
|
pd = pd->DeepCopyWithParamsAsPointers(¶m, &(SK.param));
|
||||||
} else {
|
mat.A.sym[i][j->second] = pd;
|
||||||
pd = Expr::From(0.0);
|
|
||||||
}
|
|
||||||
mat.A.sym[i][j] = pd;
|
|
||||||
}
|
}
|
||||||
|
paramsUsed.Clear();
|
||||||
mat.B.sym[i] = f;
|
mat.B.sym[i] = f;
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user