/usr/lib/python3/dist-packages/FIAT/restricted.py is in python3-fiat 2017.2.0.0-2.
This file is owned by root:root, with mode 0o644.
The actual contents of the file can be viewed below.
# Copyright (C) 2015-2016 Jan Blechta, Andrew T T McRae, and others
#
# This file is part of FIAT.
#
# FIAT is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# FIAT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with FIAT. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import, print_function, division
import six
from six import string_types
from six import iteritems
from FIAT.dual_set import DualSet
from FIAT.finite_element import CiarletElement
class RestrictedElement(CiarletElement):
    """Restrict given element to specified list of dofs.

    The restricted element keeps only the selected degrees of freedom of
    ``element``: its polynomial set is the corresponding subset of
    ``element``'s nodal basis, and its dual set consists of the
    corresponding nodes, renumbered consecutively in the order they are
    met while walking ``element.entity_dofs()``.
    """

    def __init__(self, element, indices=None, restriction_domain=None):
        '''For sake of argument, indices overrides restriction_domain.

        :arg element: the element to restrict.
        :arg indices: dof numbers of ``element`` to keep. When given and
            non-empty, ``restriction_domain`` is ignored. Any sized
            sequence (list, tuple, NumPy integer array) is accepted.
        :arg restriction_domain: name of the entities whose dofs are kept
            ('interior', 'vertex', 'edge', 'face' or 'facet'); only used
            when ``indices`` is not given.
        :raises RuntimeError: if neither ``indices`` nor
            ``restriction_domain`` is given, if ``indices`` is a string,
            or if ``restriction_domain`` is not recognised.
        :raises ValueError: if no dofs are selected, if the number of
            selected dofs found in ``element.entity_dofs()`` differs from
            ``len(indices)`` (e.g. repeated indices), or if the selected
            dofs do not all share one mapping.
        '''
        # Test emptiness with len() rather than truthiness: the truth value
        # of a NumPy index array is ambiguous and raises.
        have_indices = indices is not None and len(indices) > 0
        if not (have_indices or restriction_domain):
            raise RuntimeError("Either indices or restriction_domain must be passed in")
        if not have_indices:
            indices = _get_indices(element, restriction_domain)
        if isinstance(indices, string_types):
            # A positional string was most likely meant as restriction_domain.
            raise RuntimeError("variable 'indices' was a string; did you forget to use a keyword?")
        if len(indices) == 0:
            raise ValueError("No point in creating empty RestrictedElement.")

        # Original element and the dof numbers (in its numbering) kept.
        self._element = element
        self._indices = indices

        # Fetch reference element
        ref_el = element.get_reference_element()

        # Restrict primal set
        poly_set = element.get_nodal_basis().take(indices)

        # Restrict dual set: walk the parent's entity -> dofs table and give
        # each kept dof a new consecutive number on the same entity.
        # A set makes the membership test O(1) instead of a list scan.
        kept = set(indices)
        dof_counter = 0
        entity_ids = {}
        nodes = []
        nodes_old = element.dual_basis()
        for d, entities in iteritems(element.entity_dofs()):
            entity_ids[d] = {}
            for entity, dofs in iteritems(entities):
                entity_ids[d][entity] = []
                for dof in dofs:
                    if dof not in kept:
                        continue
                    entity_ids[d][entity].append(dof_counter)
                    dof_counter += 1
                    nodes.append(nodes_old[dof])
        # Explicit check (not assert) so it survives ``python -O``.
        if dof_counter != len(indices):
            raise ValueError("Restriction indices must be distinct dofs "
                             "listed in the element's entity_dofs().")
        dual = DualSet(nodes, ref_el, entity_ids)

        # Restrict mapping: a single element needs a single mapping.
        mapping_old = element.mapping()
        mapping_new = [mapping_old[dof] for dof in indices]
        if not all(e_mapping == mapping_new[0] for e_mapping in mapping_new):
            raise ValueError("Restricted dofs must all share the same mapping.")

        # Call constructor of CiarletElement
        super(RestrictedElement, self).__init__(poly_set, dual, 0, element.get_formdegree(), mapping_new[0])
def sorted_by_key(mapping):
    """Return the ``(key, value)`` pairs of *mapping* as a list ordered by key.

    Python 3 refuses to compare builtins of different types, so the sort key
    is the pair (name of the key's type, key): keys are grouped by type name
    first and only compared with keys of the same type.
    """
    pairs = list(iteritems(mapping))
    pairs.sort(key=lambda pair: (type(pair[0]).__name__, pair[0]))
    return pairs
def _get_indices(element, restriction_domain):
"Restriction domain can be 'interior', 'vertex', 'edge', 'face' or 'facet'"
if restriction_domain == "interior":
# Return dofs from interior
return element.entity_dofs()[max(element.entity_dofs().keys())][0]
# otherwise return dofs with d <= dim
if restriction_domain == "vertex":
dim = 0
elif restriction_domain == "edge":
dim = 1
elif restriction_domain == "face":
dim = 2
elif restriction_domain == "facet":
dim = element.get_reference_element().get_spatial_dimension() - 1
else:
raise RuntimeError("Invalid restriction domain")
is_prodcell = isinstance(max(element.entity_dofs().keys()), tuple)
entity_dofs = element.entity_dofs()
indices = []
for d in range(dim + 1):
if is_prodcell:
for a in range(d + 1):
b = d - a
try:
entities = entity_dofs[(a, b)]
for (entity, index) in sorted_by_key(entities):
indices += index
except KeyError:
pass
else:
entities = entity_dofs[d]
for (entity, index) in sorted_by_key(entities):
indices += index
return indices
|