filestore.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. """
  2. This module contains an C{L{OpenIDStore}} implementation backed by
  3. flat files.
  4. """
  5. import string
  6. import os
  7. import os.path
  8. import time
  9. import logging
  10. from errno import EEXIST, ENOENT
  11. from tempfile import mkstemp
  12. from openid.association import Association
  13. from openid.store.interface import OpenIDStore
  14. from openid.store import nonce
  15. from openid import cryptutil, oidutil
  16. logger = logging.getLogger(__name__)
  17. _filename_allowed = string.ascii_letters + string.digits + '.'
  18. _isFilenameSafe = set(_filename_allowed).__contains__
  19. def _safe64(s):
  20. h64 = oidutil.toBase64(cryptutil.sha1(s))
  21. # to be able to manipulate it, make it a bytearray
  22. h64 = bytearray(h64)
  23. h64 = h64.replace(b'+', b'_')
  24. h64 = h64.replace(b'/', b'.')
  25. h64 = h64.replace(b'=', b'')
  26. return bytes(h64)
  27. def _filenameEscape(s):
  28. filename_chunks = []
  29. for c in s:
  30. if _isFilenameSafe(c):
  31. filename_chunks.append(c)
  32. else:
  33. filename_chunks.append('_%02X' % ord(c))
  34. return ''.join(filename_chunks)
  35. def _removeIfPresent(filename):
  36. """Attempt to remove a file, returning whether the file existed at
  37. the time of the call.
  38. str -> bool
  39. """
  40. try:
  41. os.unlink(filename)
  42. except OSError as why:
  43. if why.errno == ENOENT:
  44. # Someone beat us to it, but it's gone, so that's OK
  45. return 0
  46. else:
  47. raise
  48. else:
  49. # File was present
  50. return 1
  51. def _ensureDir(dir_name):
  52. """Create dir_name as a directory if it does not exist. If it
  53. exists, make sure that it is, in fact, a directory.
  54. Can raise OSError
  55. str -> NoneType
  56. """
  57. try:
  58. os.makedirs(dir_name)
  59. except OSError as why:
  60. if why.errno != EEXIST or not os.path.isdir(dir_name):
  61. raise
  62. class FileOpenIDStore(OpenIDStore):
  63. """
  64. This is a filesystem-based store for OpenID associations and
  65. nonces. This store should be safe for use in concurrent systems
  66. on both windows and unix (excluding NFS filesystems). There are a
  67. couple race conditions in the system, but those failure cases have
  68. been set up in such a way that the worst-case behavior is someone
  69. having to try to log in a second time.
  70. Most of the methods of this class are implementation details.
  71. People wishing to just use this store need only pay attention to
  72. the C{L{__init__}} method.
  73. Methods of this object can raise OSError if unexpected filesystem
  74. conditions, such as bad permissions or missing directories, occur.
  75. """
  76. def __init__(self, directory):
  77. """
  78. Initializes a new FileOpenIDStore. This initializes the
  79. nonce and association directories, which are subdirectories of
  80. the directory passed in.
  81. @param directory: This is the directory to put the store
  82. directories in.
  83. @type directory: C{str}
  84. """
  85. # Make absolute
  86. directory = os.path.normpath(os.path.abspath(directory))
  87. self.nonce_dir = os.path.join(directory, 'nonces')
  88. self.association_dir = os.path.join(directory, 'associations')
  89. # Temp dir must be on the same filesystem as the assciations
  90. # directory
  91. self.temp_dir = os.path.join(directory, 'temp')
  92. self.max_nonce_age = 6 * 60 * 60 # Six hours, in seconds
  93. self._setup()
  94. def _setup(self):
  95. """Make sure that the directories in which we store our data
  96. exist.
  97. () -> NoneType
  98. """
  99. _ensureDir(self.nonce_dir)
  100. _ensureDir(self.association_dir)
  101. _ensureDir(self.temp_dir)
  102. def _mktemp(self):
  103. """Create a temporary file on the same filesystem as
  104. self.association_dir.
  105. The temporary directory should not be cleaned if there are any
  106. processes using the store. If there is no active process using
  107. the store, it is safe to remove all of the files in the
  108. temporary directory.
  109. () -> (file, str)
  110. """
  111. fd, name = mkstemp(dir=self.temp_dir)
  112. try:
  113. file_obj = os.fdopen(fd, 'wb')
  114. return file_obj, name
  115. except:
  116. _removeIfPresent(name)
  117. raise
  118. def getAssociationFilename(self, server_url, handle):
  119. """Create a unique filename for a given server url and
  120. handle. This implementation does not assume anything about the
  121. format of the handle. The filename that is returned will
  122. contain the domain name from the server URL for ease of human
  123. inspection of the data directory.
  124. (str, str) -> str
  125. """
  126. if server_url.find('://') == -1:
  127. raise ValueError('Bad server URL: %r' % server_url)
  128. proto, rest = server_url.split('://', 1)
  129. domain = _filenameEscape(rest.split('/', 1)[0])
  130. url_hash = _safe64(server_url)
  131. if handle:
  132. handle_hash = _safe64(handle)
  133. else:
  134. handle_hash = ''
  135. filename = '%s-%s-%s-%s' % (proto, domain, url_hash, handle_hash)
  136. return os.path.join(self.association_dir, filename)
  137. def storeAssociation(self, server_url, association):
  138. """Store an association in the association directory.
  139. (str, Association) -> NoneType
  140. """
  141. association_s = association.serialize() # NOTE: UTF-8 encoded bytes
  142. filename = self.getAssociationFilename(server_url, association.handle)
  143. tmp_file, tmp = self._mktemp()
  144. try:
  145. try:
  146. tmp_file.write(association_s)
  147. os.fsync(tmp_file.fileno())
  148. finally:
  149. tmp_file.close()
  150. try:
  151. os.rename(tmp, filename)
  152. except OSError as why:
  153. if why.errno != EEXIST:
  154. raise
  155. # We only expect EEXIST to happen only on Windows. It's
  156. # possible that we will succeed in unlinking the existing
  157. # file, but not in putting the temporary file in place.
  158. try:
  159. os.unlink(filename)
  160. except OSError as why:
  161. if why.errno == ENOENT:
  162. pass
  163. else:
  164. raise
  165. # Now the target should not exist. Try renaming again,
  166. # giving up if it fails.
  167. os.rename(tmp, filename)
  168. except:
  169. # If there was an error, don't leave the temporary file
  170. # around.
  171. _removeIfPresent(tmp)
  172. raise
  173. def getAssociation(self, server_url, handle=None):
  174. """Retrieve an association. If no handle is specified, return
  175. the association with the latest expiration.
  176. (str, str or NoneType) -> Association or NoneType
  177. """
  178. if handle is None:
  179. handle = ''
  180. # The filename with the empty handle is a prefix of all other
  181. # associations for the given server URL.
  182. filename = self.getAssociationFilename(server_url, handle)
  183. if handle:
  184. return self._getAssociation(filename)
  185. else:
  186. association_files = os.listdir(self.association_dir)
  187. matching_files = []
  188. # strip off the path to do the comparison
  189. name = os.path.basename(filename)
  190. for association_file in association_files:
  191. if association_file.startswith(name):
  192. matching_files.append(association_file)
  193. matching_associations = []
  194. # read the matching files and sort by time issued
  195. for name in matching_files:
  196. full_name = os.path.join(self.association_dir, name)
  197. association = self._getAssociation(full_name)
  198. if association is not None:
  199. matching_associations.append(
  200. (association.issued, association))
  201. matching_associations.sort()
  202. # return the most recently issued one.
  203. if matching_associations:
  204. (_, assoc) = matching_associations[-1]
  205. return assoc
  206. else:
  207. return None
  208. def _getAssociation(self, filename):
  209. try:
  210. assoc_file = open(filename, 'rb')
  211. except IOError as why:
  212. if why.errno == ENOENT:
  213. # No association exists for that URL and handle
  214. return None
  215. else:
  216. raise
  217. try:
  218. assoc_s = assoc_file.read()
  219. finally:
  220. assoc_file.close()
  221. try:
  222. association = Association.deserialize(assoc_s)
  223. except ValueError:
  224. _removeIfPresent(filename)
  225. return None
  226. # Clean up expired associations
  227. if association.expiresIn == 0:
  228. _removeIfPresent(filename)
  229. return None
  230. else:
  231. return association
  232. def removeAssociation(self, server_url, handle):
  233. """Remove an association if it exists. Do nothing if it does not.
  234. (str, str) -> bool
  235. """
  236. assoc = self.getAssociation(server_url, handle)
  237. if assoc is None:
  238. return 0
  239. else:
  240. filename = self.getAssociationFilename(server_url, handle)
  241. return _removeIfPresent(filename)
  242. def useNonce(self, server_url, timestamp, salt):
  243. """Return whether this nonce is valid.
  244. str -> bool
  245. """
  246. if abs(timestamp - time.time()) > nonce.SKEW:
  247. return False
  248. if server_url:
  249. proto, rest = server_url.split('://', 1)
  250. else:
  251. # Create empty proto / rest values for empty server_url,
  252. # which is part of a consumer-generated nonce.
  253. proto, rest = '', ''
  254. domain = _filenameEscape(rest.split('/', 1)[0])
  255. url_hash = _safe64(server_url)
  256. salt_hash = _safe64(salt)
  257. filename = '%08x-%s-%s-%s-%s' % (timestamp, proto, domain, url_hash,
  258. salt_hash)
  259. filename = os.path.join(self.nonce_dir, filename)
  260. try:
  261. fd = os.open(filename, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o200)
  262. except OSError as why:
  263. if why.errno == EEXIST:
  264. return False
  265. else:
  266. raise
  267. else:
  268. os.close(fd)
  269. return True
  270. def _allAssocs(self):
  271. all_associations = []
  272. association_filenames = [
  273. os.path.join(self.association_dir, filename)
  274. for filename in os.listdir(self.association_dir)
  275. ]
  276. for association_filename in association_filenames:
  277. try:
  278. association_file = open(association_filename, 'rb')
  279. except IOError as why:
  280. if why.errno == ENOENT:
  281. logger.exception("%s disappeared during %s._allAssocs" % (
  282. association_filename, self.__class__.__name__))
  283. else:
  284. raise
  285. else:
  286. try:
  287. assoc_s = association_file.read()
  288. finally:
  289. association_file.close()
  290. # Remove expired or corrupted associations
  291. try:
  292. association = Association.deserialize(assoc_s)
  293. except ValueError:
  294. _removeIfPresent(association_filename)
  295. else:
  296. all_associations.append(
  297. (association_filename, association))
  298. return all_associations
  299. def cleanup(self):
  300. """Remove expired entries from the database. This is
  301. potentially expensive, so only run when it is acceptable to
  302. take time.
  303. () -> NoneType
  304. """
  305. self.cleanupAssociations()
  306. self.cleanupNonces()
  307. def cleanupAssociations(self):
  308. removed = 0
  309. for assoc_filename, assoc in self._allAssocs():
  310. if assoc.expiresIn == 0:
  311. _removeIfPresent(assoc_filename)
  312. removed += 1
  313. return removed
  314. def cleanupNonces(self):
  315. nonces = os.listdir(self.nonce_dir)
  316. now = time.time()
  317. removed = 0
  318. # Check all nonces for expiry
  319. for nonce_fname in nonces:
  320. timestamp = nonce_fname.split('-', 1)[0]
  321. timestamp = int(timestamp, 16)
  322. if abs(timestamp - now) > nonce.SKEW:
  323. filename = os.path.join(self.nonce_dir, nonce_fname)
  324. _removeIfPresent(filename)
  325. removed += 1
  326. return removed