Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
D
dune-istl
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Core Modules
dune-istl
Commits
82263ea9
Commit
82263ea9
authored
16 years ago
by
Markus Blatt
Browse files
Options
Downloads
Patches
Plain Diff
Sparse matrix matrix multiplication due to popular demand.
[[Imported from SVN: r1000]]
parent
0f0b8760
Branches
Branches containing commit
Tags
Tags containing commit
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
istl/matrixmatrix.hh
+563
-0
563 additions, 0 deletions
istl/matrixmatrix.hh
with
563 additions
and
0 deletions
istl/matrixmatrix.hh
0 → 100644
+
563
−
0
View file @
82263ea9
// -*- tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*-
// vi: set et ts=4 sw=2 sts=2:
#ifndef DUNE_MATRIXMATRIX_HH
#define DUNE_MATRIXMATRIX_HH
#include
<dune/istl/bcrsmatrix.hh>
#include
<dune/common/fmatrix.hh>
#include
<dune/common/tuples.hh>
#include
<dune/common/timer.hh>
namespace
Dune
{
/**
* @addtogroup ISTL_SPMV
*
* @{
*/
/** @file
* @author Markus Blatt
* @brief provides functions for sparse matrix matrix multiplication.
*/
namespace
{
/**
* @brief Traverses over the nonzero pattern of the matrix-matrix product.
*
* Template parameter b is used to select the matrix product:
* <dt>0</dt><dd>\f$A\cdot B$\f</dd>
* <dt>1</dt><dd>\f$A^T\cdot B$\f</dd>
* <dt>2</dt><dd>\f$A\cdot B^T\f</dd>
*/
template
<
int
b
>
struct
NonzeroPatternTraverser
{};
template
<
>
struct
NonzeroPatternTraverser
<
0
>
{
template
<
class
T
,
class
A1
,
class
A2
,
class
F
,
int
n
,
int
m
,
int
k
>
static
void
traverse
(
const
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
n
,
k
>
,
A1
>&
A
,
const
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
k
,
m
>
,
A2
>&
B
,
F
&
func
)
{
typedef
typename
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
k
,
m
>
,
A2
>::
size_type
size_type
;
if
(
A
.
M
()
!=
B
.
N
())
DUNE_THROW
(
ISTLError
,
"The sizes of the matrices do not match: "
<<
A
.
M
()
<<
"!="
<<
B
.
N
());
typedef
typename
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
n
,
k
>
,
A1
>::
ConstRowIterator
Row
;
typedef
typename
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
n
,
k
>
,
A1
>::
ConstColIterator
Col
;
typedef
typename
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
k
,
m
>
,
A2
>::
ConstRowIterator
BRow
;
typedef
typename
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
k
,
m
>
,
A2
>::
ConstColIterator
BCol
;
for
(
Row
row
=
A
.
begin
();
row
!=
A
.
end
();
++
row
)
{
// Loop over all column entries
for
(
Col
col
=
row
->
begin
();
col
!=
row
->
end
();
++
col
)
{
// entry at i,k
// search for all nonzeros in row k
for
(
BCol
bcol
=
B
[
col
.
index
()].
begin
();
bcol
!=
B
[
col
.
index
()].
end
();
++
bcol
)
{
func
(
*
col
,
*
bcol
,
row
.
index
(),
bcol
.
index
());
}
}
}
}
};
template
<
>
struct
NonzeroPatternTraverser
<
1
>
{
template
<
class
T
,
class
A1
,
class
A2
,
class
F
,
int
n
,
int
m
,
int
k
>
static
void
traverse
(
const
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
k
,
n
>
,
A1
>&
A
,
const
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
k
,
m
>
,
A2
>&
B
,
F
&
func
)
{
if
(
A
.
N
()
!=
B
.
N
())
DUNE_THROW
(
ISTLError
,
"The sizes of the matrices do not match: "
<<
A
.
N
()
<<
"!="
<<
B
.
N
());
typedef
typename
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
k
,
n
>
,
A1
>::
ConstRowIterator
Row
;
typedef
typename
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
k
,
n
>
,
A1
>::
ConstColIterator
Col
;
typedef
typename
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
k
,
m
>
,
A2
>::
ConstColIterator
BCol
;
typedef
typename
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
k
,
n
>
,
A1
>::
size_type
size_t1
;
typedef
typename
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
k
,
m
>
,
A2
>::
size_type
size_t2
;
for
(
Row
row
=
A
.
begin
();
row
!=
A
.
end
();
++
row
)
{
for
(
Col
col
=
row
->
begin
();
col
!=
row
->
end
();
++
col
)
{
for
(
BCol
bcol
=
B
[
row
.
index
()].
begin
();
bcol
!=
B
[
row
.
index
()].
end
();
++
bcol
)
{
func
(
*
col
,
*
bcol
,
col
.
index
(),
bcol
.
index
());
}
}
}
}
};
template
<
>
struct
NonzeroPatternTraverser
<
2
>
{
template
<
class
T
,
class
A1
,
class
A2
,
class
F
,
int
n
,
int
m
,
int
k
>
static
void
traverse
(
const
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
m
>
,
A1
>&
mat
,
const
BCRSMatrix
<
FieldMatrix
<
T
,
k
,
m
>
,
A2
>&
matt
,
F
&
func
)
{
if
(
mat
.
M
()
!=
matt
.
M
())
DUNE_THROW
(
ISTLError
,
"The sizes of the matrices do not match: "
<<
mat
.
N
()
<<
"!="
<<
matt
.
N
());
typedef
typename
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
m
>
,
A1
>::
ConstRowIterator
row_iterator
;
typedef
typename
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
m
>
,
A1
>::
ConstColIterator
col_iterator
;
typedef
typename
BCRSMatrix
<
FieldMatrix
<
T
,
k
,
m
>
,
A2
>::
ConstRowIterator
row_iterator_t
;
typedef
typename
BCRSMatrix
<
FieldMatrix
<
T
,
k
,
m
>
,
A2
>::
ConstColIterator
col_iterator_t
;
for
(
row_iterator
mrow
=
mat
.
begin
();
mrow
!=
mat
.
end
();
++
mrow
)
{
//iterate over the column entries
// mt is a transposed matrix crs therefore it is treated as a ccs matrix
// and the row_iterator iterates over the columns of the transposed matrix.
// search the row of the transposed matrix for an entry with the same index
// as the mcol iterator
for
(
row_iterator_t
mtcol
=
matt
.
begin
();
mtcol
!=
matt
.
end
();
++
mtcol
,
func
.
nextCol
())
{
//Search for col entries in mat that have a corrsponding row index in matt
// (i.e. corresponding col index in the as this is the transposed matrix
col_iterator_t
mtrow
=
mtcol
->
begin
();
for
(
col_iterator
mcol
=
mrow
->
begin
();
mcol
!=
mrow
->
end
();
++
mcol
)
{
// search
// TODO: This should probably be substituted by a binary search
for
(
;
mtrow
!=
mtcol
->
end
();
++
mtrow
)
if
(
mtrow
.
index
()
>=
mcol
.
index
())
break
;
if
(
mtrow
!=
mtcol
->
end
()
&&
mtrow
.
index
()
==
mcol
.
index
())
{
func
(
*
mcol
,
*
mtrow
,
mtcol
.
index
());
// In some cases we only search for one pair, then we break here
// and continue with the next column.
if
(
F
::
do_break
)
break
;
}
}
}
func
.
nextRow
();
}
}
};
template
<
class
T
,
class
A
,
int
n
,
int
m
>
class
SparsityPatternInitializer
{
public:
enum
{
do_break
=
true
};
typedef
typename
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
m
>
,
A
>::
CreateIterator
CreateIterator
;
typedef
typename
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
m
>
,
A
>::
size_type
size_type
;
SparsityPatternInitializer
(
CreateIterator
iter
)
:
rowiter
(
iter
)
{}
template
<
class
T1
,
class
T2
>
void
operator
()(
const
T1
&
t1
,
const
T2
&
t2
,
size_type
j
)
{
rowiter
.
insert
(
j
);
}
void
nextRow
()
{
++
rowiter
;
}
void
nextCol
()
{}
private
:
CreateIterator
rowiter
;
};
template
<
int
transpose
,
class
T
,
class
TA
,
int
n
,
int
m
>
class
MatrixInitializer
{
public:
enum
{
do_break
=
true
};
typedef
typename
Dune
::
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
m
>
,
TA
>
Matrix
;
typedef
typename
Matrix
::
CreateIterator
CreateIterator
;
typedef
typename
Matrix
::
size_type
size_type
;
MatrixInitializer
(
Matrix
&
A_
,
size_type
rows
)
:
count
(
0
),
A
(
A_
)
{}
template
<
class
T1
,
class
T2
>
void
operator
()(
const
T1
&
t1
,
const
T2
&
t2
,
int
j
)
{
++
count
;
}
void
nextCol
()
{}
void
nextRow
()
{}
std
::
size_t
nonzeros
()
{
return
count
;
}
template
<
class
A1
,
class
A2
,
int
n2
,
int
m2
,
int
n3
,
int
m3
>
void
initPattern
(
const
BCRSMatrix
<
FieldMatrix
<
T
,
n2
,
m2
>
,
A1
>&
mat1
,
const
BCRSMatrix
<
FieldMatrix
<
T
,
n3
,
m3
>
,
A2
>&
mat2
)
{
SparsityPatternInitializer
<
T
,
TA
,
n
,
m
>
sparsity
(
A
.
createbegin
());
NonzeroPatternTraverser
<
transpose
>::
traverse
(
mat1
,
mat2
,
sparsity
);
}
private
:
std
::
size_t
count
;
Matrix
&
A
;
};
template
<
class
T
,
class
TA
,
int
n
,
int
m
>
class
MatrixInitializer
<
1
,
T
,
TA
,
n
,
m
>
{
public:
enum
{
do_break
=
false
};
typedef
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
n
,
m
>
,
TA
>
Matrix
;
typedef
typename
Matrix
::
CreateIterator
CreateIterator
;
typedef
typename
Matrix
::
size_type
size_type
;
MatrixInitializer
(
Matrix
&
A_
,
size_type
rows
)
:
A
(
A_
),
entries
(
rows
)
{}
template
<
class
T1
,
class
T2
>
void
operator
()(
const
T1
&
t1
,
const
T2
&
t2
,
size_type
i
,
size_type
j
)
{
entries
[
i
].
insert
(
j
);
}
void
nextCol
()
{}
size_type
nonzeros
()
{
size_type
nnz
=
0
;
typedef
typename
std
::
vector
<
std
::
set
<
size_t
>
>::
const_iterator
Iter
;
for
(
Iter
iter
=
entries
.
begin
();
iter
!=
entries
.
end
();
++
iter
)
nnz
+=
(
*
iter
).
size
();
return
nnz
;
}
template
<
class
A1
,
class
A2
,
int
n2
,
int
m2
,
int
n3
,
int
m3
>
void
initPattern
(
const
BCRSMatrix
<
FieldMatrix
<
T
,
n2
,
m2
>
,
A1
>&
mat1
,
const
BCRSMatrix
<
FieldMatrix
<
T
,
n3
,
m3
>
,
A2
>&
mat2
)
{
typedef
typename
std
::
vector
<
std
::
set
<
size_t
>
>::
const_iterator
Iter
;
CreateIterator
citer
=
A
.
createbegin
();
for
(
Iter
iter
=
entries
.
begin
();
iter
!=
entries
.
end
();
++
iter
,
++
citer
)
{
typedef
std
::
set
<
size_t
>::
const_iterator
SetIter
;
for
(
SetIter
index
=
iter
->
begin
();
index
!=
iter
->
end
();
++
index
)
citer
.
insert
(
*
index
);
}
}
private
:
Matrix
&
A
;
std
::
vector
<
std
::
set
<
size_t
>
>
entries
;
};
template
<
class
T
,
class
TA
,
int
n
,
int
m
>
struct
MatrixInitializer
<
0
,
T
,
TA
,
n
,
m
>
:
public
MatrixInitializer
<
1
,
T
,
TA
,
n
,
m
>
{
MatrixInitializer
(
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
n
,
m
>
,
TA
>&
A_
,
typename
Dune
::
BCRSMatrix
<
Dune
::
FieldMatrix
<
T
,
n
,
m
>
,
TA
>::
size_type
rows
)
:
MatrixInitializer
<
1
,
T
,
TA
,
n
,
m
>
(
A_
,
rows
)
{}
};
template
<
class
T
,
class
T1
,
class
T2
,
int
n
,
int
m
,
int
k
>
void
addMatMultTransposeMat
(
FieldMatrix
<
T
,
n
,
k
>&
res
,
const
FieldMatrix
<
T1
,
n
,
m
>&
mat
,
const
FieldMatrix
<
T2
,
k
,
m
>&
matt
)
{
typedef
typename
FieldMatrix
<
T
,
n
,
k
>::
size_type
size_type
;
for
(
size_type
row
=
0
;
row
<
n
;
++
row
)
for
(
size_type
col
=
0
;
col
<
k
;
++
col
)
{
for
(
size_type
i
=
0
;
i
<
m
;
++
i
)
res
[
row
][
col
]
+=
mat
[
row
][
i
]
*
matt
[
col
][
i
];
}
}
template
<
class
T
,
class
T1
,
class
T2
,
int
n
,
int
m
,
int
k
>
void
addTransposeMatMultMat
(
FieldMatrix
<
T
,
n
,
k
>&
res
,
const
FieldMatrix
<
T1
,
m
,
n
>&
mat
,
const
FieldMatrix
<
T2
,
m
,
k
>&
matt
)
{
typedef
typename
FieldMatrix
<
T
,
n
,
k
>::
size_type
size_type
;
for
(
size_type
i
=
0
;
i
<
m
;
++
i
)
for
(
size_type
row
=
0
;
row
<
n
;
++
row
)
{
for
(
size_type
col
=
0
;
col
<
k
;
++
col
)
res
[
row
][
col
]
+=
mat
[
i
][
row
]
*
matt
[
i
][
col
];
}
}
template
<
class
T
,
class
T1
,
class
T2
,
int
n
,
int
m
,
int
k
>
void
addMatMultMat
(
FieldMatrix
<
T
,
n
,
m
>&
res
,
const
FieldMatrix
<
T1
,
n
,
k
>&
mat
,
const
FieldMatrix
<
T2
,
k
,
m
>&
matt
)
{
typedef
typename
FieldMatrix
<
T
,
n
,
k
>::
size_type
size_type
;
for
(
size_type
row
=
0
;
row
<
n
;
++
row
)
for
(
size_type
col
=
0
;
col
<
m
;
++
col
)
{
for
(
size_type
i
=
0
;
i
<
k
;
++
i
)
res
[
row
][
col
]
+=
mat
[
row
][
i
]
*
matt
[
i
][
col
];
}
}
template
<
class
T
,
class
A
,
int
n
,
int
m
>
class
EntryAccumulatorFather
{
public:
enum
{
do_break
=
false
};
typedef
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
m
>
,
A
>
Matrix
;
typedef
typename
Matrix
::
RowIterator
Row
;
typedef
typename
Matrix
::
ColIterator
Col
;
EntryAccumulatorFather
(
Matrix
&
mat_
)
:
mat
(
mat_
),
row
(
mat
.
begin
())
{
mat
=
0
;
col
=
row
->
begin
();
}
void
nextRow
()
{
++
row
;
if
(
row
!=
mat
.
end
())
col
=
row
->
begin
();
}
void
nextCol
()
{
++
this
->
col
;
}
protected
:
Matrix
&
mat
;
private
:
Row
row
;
protected
:
Col
col
;
};
template
<
class
T
,
class
A
,
int
n
,
int
m
,
int
transpose
>
class
EntryAccumulator
:
public
EntryAccumulatorFather
<
T
,
A
,
n
,
m
>
{
public:
typedef
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
m
>
,
A
>
Matrix
;
typedef
typename
Matrix
::
size_type
size_type
;
EntryAccumulator
(
Matrix
&
mat_
)
:
EntryAccumulatorFather
<
T
,
A
,
n
,
m
>
(
mat_
)
{}
template
<
class
T1
,
class
T2
>
void
operator
()(
const
T1
&
t1
,
const
T2
&
t2
,
size_type
i
)
{
assert
(
this
->
col
.
index
()
==
i
);
addMatMultMat
(
*
(
this
->
col
),
t1
,
t2
);
}
};
template
<
class
T
,
class
A
,
int
n
,
int
m
>
class
EntryAccumulator
<
T
,
A
,
n
,
m
,
0
>
:
public
EntryAccumulatorFather
<
T
,
A
,
n
,
m
>
{
public:
typedef
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
m
>
,
A
>
Matrix
;
typedef
typename
Matrix
::
size_type
size_type
;
EntryAccumulator
(
Matrix
&
mat_
)
:
EntryAccumulatorFather
<
T
,
A
,
n
,
m
>
(
mat_
)
{}
template
<
class
T1
,
class
T2
>
void
operator
()(
const
T1
&
t1
,
const
T2
&
t2
,
size_type
i
,
size_type
j
)
{
addMatMultMat
(
this
->
mat
[
i
][
j
],
t1
,
t2
);
}
};
template
<
class
T
,
class
A
,
int
n
,
int
m
>
class
EntryAccumulator
<
T
,
A
,
n
,
m
,
1
>
:
public
EntryAccumulatorFather
<
T
,
A
,
n
,
m
>
{
public:
typedef
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
m
>
,
A
>
Matrix
;
typedef
typename
Matrix
::
size_type
size_type
;
EntryAccumulator
(
Matrix
&
mat_
)
:
EntryAccumulatorFather
<
T
,
A
,
n
,
m
>
(
mat_
)
{}
template
<
class
T1
,
class
T2
>
void
operator
()(
const
T1
&
t1
,
const
T2
&
t2
,
size_type
i
,
size_type
j
)
{
addTransposeMatMultMat
(
this
->
mat
[
i
][
j
],
t1
,
t2
);
}
};
template
<
class
T
,
class
A
,
int
n
,
int
m
>
class
EntryAccumulator
<
T
,
A
,
n
,
m
,
2
>
:
public
EntryAccumulatorFather
<
T
,
A
,
n
,
m
>
{
public:
typedef
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
m
>
,
A
>
Matrix
;
typedef
typename
Matrix
::
size_type
size_type
;
EntryAccumulator
(
Matrix
&
mat_
)
:
EntryAccumulatorFather
<
T
,
A
,
n
,
m
>
(
mat_
)
{}
template
<
class
T1
,
class
T2
>
void
operator
()(
const
T1
&
t1
,
const
T2
&
t2
,
size_type
i
)
{
assert
(
this
->
col
.
index
()
==
i
);
addMatMultTransposeMat
(
*
this
->
col
,
t1
,
t2
);
}
};
template
<
int
transpose
>
struct
SizeSelector
{};
template
<
>
struct
SizeSelector
<
0
>
{
template
<
class
M1
,
class
M2
>
static
tuple
<
typename
M1
::
size_type
,
typename
M2
::
size_type
>
size
(
const
M1
&
m1
,
const
M2
&
m2
)
{
return
make_tuple
(
m1
.
N
(),
m2
.
M
());
}
};
template
<
>
struct
SizeSelector
<
1
>
{
template
<
class
M1
,
class
M2
>
static
tuple
<
typename
M1
::
size_type
,
typename
M2
::
size_type
>
size
(
const
M1
&
m1
,
const
M2
&
m2
)
{
return
make_tuple
(
m1
.
M
(),
m2
.
M
());
}
};
template
<
>
struct
SizeSelector
<
2
>
{
template
<
class
M1
,
class
M2
>
static
tuple
<
typename
M1
::
size_type
,
typename
M2
::
size_type
>
size
(
const
M1
&
m1
,
const
M2
&
m2
)
{
return
make_tuple
(
m1
.
N
(),
m2
.
N
());
}
};
template
<
int
transpose
,
class
T
,
class
A
,
class
A1
,
class
A2
,
int
n1
,
int
m1
,
int
n2
,
int
m2
,
int
n3
,
int
m3
>
void
matMultMat
(
BCRSMatrix
<
FieldMatrix
<
T
,
n1
,
m1
>
,
A
>&
res
,
const
BCRSMatrix
<
FieldMatrix
<
T
,
n2
,
m2
>
,
A1
>&
mat1
,
const
BCRSMatrix
<
FieldMatrix
<
T
,
n3
,
m3
>
,
A2
>&
mat2
)
{
// First step is to count the number of nonzeros
typename
BCRSMatrix
<
FieldMatrix
<
T
,
n1
,
m1
>
,
A
>::
size_type
rows
,
cols
;
tie
(
rows
,
cols
)
=
SizeSelector
<
transpose
>::
size
(
mat1
,
mat2
);
MatrixInitializer
<
transpose
,
T
,
A1
,
n1
,
m1
>
patternInit
(
res
,
rows
);
Timer
timer
;
NonzeroPatternTraverser
<
transpose
>::
traverse
(
mat1
,
mat2
,
patternInit
);
res
.
setSize
(
rows
,
cols
,
patternInit
.
nonzeros
());
res
.
setBuildMode
(
BCRSMatrix
<
FieldMatrix
<
T
,
n1
,
m1
>
,
A1
>::
row_wise
);
std
::
cout
<<
"Counting nonzeros took "
<<
timer
.
elapsed
()
<<
std
::
endl
;
timer
.
reset
();
// Second step is to allocate the storage for the result and initialize the nonzero pattern
patternInit
.
initPattern
(
mat1
,
mat2
);
std
::
cout
<<
"Setting up sparsity pattern took "
<<
timer
.
elapsed
()
<<
std
::
endl
;
timer
.
reset
();
// As a last step calculate the entries
EntryAccumulator
<
T
,
A1
,
n1
,
m1
,
transpose
>
entriesAccu
(
res
);
NonzeroPatternTraverser
<
transpose
>::
traverse
(
mat1
,
mat2
,
entriesAccu
);
std
::
cout
<<
"Calculating entries took "
<<
timer
.
elapsed
()
<<
std
::
endl
;
}
}
/**
* @brief Helper TMP to get the result type of a sparse matrix matrix multiplication (C=A*B)
*
* The type of matrix C will be stored as the associated type MatMultMatResult::type.
* @tparam M1 The type of matrix A.
* @tparam M2 The type of matrix B.
*/
template
<
typename
M1
,
typename
M2
>
struct
MatMultMatResult
{};
template
<
typename
T
,
int
n
,
int
k
,
int
m
>
struct
MatMultMatResult
<
FieldMatrix
<
T
,
n
,
k
>
,
FieldMatrix
<
T
,
k
,
m
>
>
{
typedef
FieldMatrix
<
T
,
n
,
m
>
type
;
};
template
<
typename
T
,
typename
A
,
int
n
,
int
k
,
int
m
>
struct
MatMultMatResult
<
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
k
>
,
A
>
,
BCRSMatrix
<
FieldMatrix
<
T
,
k
,
m
>
,
A
>
>
{
typedef
BCRSMatrix
<
typename
MatMultMatResult
<
FieldMatrix
<
T
,
n
,
k
>
,
FieldMatrix
<
T
,
k
,
m
>
>::
type
,
A
>
type
;
};
/**
* @brief Calculate product of a sparse matrix with a transposed sparse matrices (\fC=A*B^T\f).
*
* @param res Matrix for the result of the computation.
* @param mat Matrix A.
* @param matt Matrix B, which will be transposed before the multiplication.
*/
template
<
class
T
,
class
A
,
class
A1
,
class
A2
,
int
n
,
int
m
,
int
k
>
void
matMultTransposeMat
(
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
k
>
,
A
>&
res
,
const
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
m
>
,
A1
>&
mat
,
const
BCRSMatrix
<
FieldMatrix
<
T
,
k
,
m
>
,
A2
>&
matt
,
bool
tryHard
=
false
)
{
matMultMat
<
2
>
(
res
,
mat
,
matt
);
}
/**
* @brief Calculate product of two sparse matrices (C=A*B).
*
* @param res Matrix for the result of the computation.
* @param mat Matrix A.
* @param matt Matrix B.
*/
template
<
class
T
,
class
A
,
class
A1
,
class
A2
,
int
n
,
int
m
,
int
k
>
void
matMultMat
(
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
m
>
,
A
>&
res
,
const
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
k
>
,
A1
>&
mat
,
const
BCRSMatrix
<
FieldMatrix
<
T
,
k
,
m
>
,
A2
>&
matt
,
bool
tryHard
=
false
)
{
matMultMat
<
0
>
(
res
,
mat
,
matt
);
}
/**
* @brief Calculate product of a transposed sparse matrix with another sparse matrices (\fC=A^T*B\f).
*
* @param res Matrix for the result of the computation.
* @param mat Matrix A, which will be transposed before the multiplication.
* @param matt Matrix B.
*/
template
<
class
T
,
class
A
,
class
A1
,
class
A2
,
int
n
,
int
m
,
int
k
>
void
transposeMatMultMat
(
BCRSMatrix
<
FieldMatrix
<
T
,
n
,
m
>
,
A
>&
res
,
const
BCRSMatrix
<
FieldMatrix
<
T
,
k
,
n
>
,
A1
>&
mat
,
const
BCRSMatrix
<
FieldMatrix
<
T
,
k
,
m
>
,
A2
>&
matt
,
bool
tryHard
=
false
)
{
matMultMat
<
1
>
(
res
,
mat
,
matt
);
}
}
#endif
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment