Source code for sqliterc.schema_extractor

"""SQLite database file schema extractor."""

import logging
import os
import re
import sqlite3
import tempfile
import textwrap

from artifacts import definitions as artifacts_definitions
from artifacts import reader as artifacts_reader
from artifacts import registry as artifacts_registry

from dfimagetools import definitions as dfimagetools_definitions
from dfimagetools import file_entry_lister

from sqliterc import resources
from sqliterc import yaml_definitions_file


class SQLiteSchemaExtractor:
    """SQLite database file schema extractor.

    Lists file entries of a path (a single file or a storage media image),
    selects those carrying the SQLite 3 file signature, reads the table
    definitions from sqlite_master and can format them as text or YAML.
    """

    # YAML file that maps database identifiers to artifact definition names.
    _DATABASE_DEFINITIONS_FILE = os.path.join(
        os.path.dirname(__file__), "data", "known_databases.yaml"
    )

    # Opening and matching closing character of a quoted SQL identifier.
    _IDENTIFIER_QUOTES = {"'": "'", '"': '"', "`": "`", "[": "]"}

    # File entries smaller than this are skipped by ExtractSchemas.
    _MINIMUM_FILE_SIZE = 16

    # Chunk size used when copying a database into a temporary file.
    _READ_BUFFER_SIZE = 16 * 1024 * 1024

    # String literals use single quotes: double-quoted literals are a legacy
    # SQLite quirk that can be disabled at compile time.
    _SCHEMA_QUERY = (
        "SELECT tbl_name, sql "
        "FROM sqlite_master "
        "WHERE type = 'table' AND tbl_name != 'xp_proc' "
        "AND tbl_name != 'sqlite_sequence'"
    )

    def __init__(self, artifact_definitions, mediator=None):
        """Initializes a SQLite database file schema extractor.

        Args:
          artifact_definitions (str): path to a single artifact definitions
              YAML file or a directory of definitions YAML files.
          mediator (Optional[dfvfs.VolumeScannerMediator]): a volume scanner
              mediator.
        """
        super().__init__()
        self._artifacts_registry = artifacts_registry.ArtifactDefinitionsRegistry()
        self._known_database_definitions = {}
        self._mediator = mediator

        if artifact_definitions:
            reader = artifacts_reader.YamlArtifactsReader()
            if os.path.isdir(artifact_definitions):
                self._artifacts_registry.ReadFromDirectory(
                    reader, artifact_definitions
                )
            elif os.path.isfile(artifact_definitions):
                self._artifacts_registry.ReadFromFile(reader, artifact_definitions)

        # Map every known database identifier onto its artifact definition,
        # warning about entries that have no definition in the registry.
        definitions_file = yaml_definitions_file.YAMLDatabaseDefinitionsFile()
        for database_definition in definitions_file.ReadFromFile(
            self._DATABASE_DEFINITIONS_FILE
        ):
            artifact_definition = self._artifacts_registry.GetDefinitionByName(
                database_definition.artifact_definition
            )
            if not artifact_definition:
                logging.warning(
                    (
                        f"Unknown artifact definition: "
                        f"{database_definition.artifact_definition:s}"
                    )
                )
            else:
                self._known_database_definitions[
                    database_definition.database_identifier
                ] = artifact_definition

    def _CheckSignature(self, file_object):
        """Checks the signature of a given file-like object.

        Args:
          file_object (dfvfs.FileIO): file-like object of the SQLite database.

        Returns:
          bool: True if the signature matches that of a SQLite database, False
              otherwise.
        """
        if not file_object:
            return False

        file_object.seek(0, os.SEEK_SET)
        file_data = file_object.read(16)
        return file_data == b"SQLite format 3\x00"

    def _FormatSchemaAsText(self, schema):
        """Formats a schema into a word-wrapped string.

        Args:
          schema (dict[str, str]): schema as an SQL query per table name.

        Returns:
          str: schema formatted as word-wrapped string.
        """
        textwrapper = textwrap.TextWrapper()
        textwrapper.break_long_words = False
        textwrapper.drop_whitespace = True
        # NOTE(review): the width reserves 10 + 4 columns for indentation and
        # quoting; confirm the literal prefixes below carry the intended
        # indentation.
        textwrapper.width = 80 - (10 + 4)

        lines = []
        number_of_tables = len(schema)
        for table_index, (table_name, query) in enumerate(
            sorted(schema.items()), start=1
        ):
            lines.append(f" '{table_name:s}': (")

            # Replace \t and \n by a space.
            query = re.sub(r"[\n\t]+", r" ", query, count=0)
            query = query.replace("'", "\\'")

            # An empty query wraps to no lines at all; fall back to a single
            # empty segment so that the entry can still be terminated.
            wrapped_query = textwrapper.wrap(query) or [""]

            # Every segment ends in a space inside the quotes so that the
            # adjacent string literals concatenate into the original query.
            quoted_lines = [f" '{segment:s} '" for segment in wrapped_query]

            # The last segment drops the trailing " '" and closes the entry,
            # or the entire structure for the last table.
            if table_index == number_of_tables:
                terminator = "')}}]"
            else:
                terminator = "'),"
            quoted_lines[-1] = "".join([quoted_lines[-1][:-2], terminator])

            lines.extend(quoted_lines)

        return "\n".join(lines)

    def _ParseColumnDefinitions(self, query):
        """Parses the column definitions of a CREATE TABLE query.

        Args:
          query (str): part of the CREATE TABLE query between the outer
              parentheses, without leading whitespace.

        Returns:
          dict[str, resources.ColumnDefinition]: column definitions per column
              name, in order of definition.

        Raises:
          RuntimeError: if a column is defined more than once.
        """
        column_definitions = {}
        while query:
            # Strip comments, of which several can follow each other.
            while query.startswith("--"):
                _, _, query = query.partition("\n")
                query = query.lstrip()

            # Table constraints end the list of column definitions.
            # TODO: set unique status in column definition.
            # TODO: set primary key status in column definition.
            if query.startswith(("CONSTRAINT", "UNIQUE", "PRIMARY KEY")):
                break

            # TODO: handle CONSTRAINT
            column, _, query = query.partition(",")
            query = query.lstrip()

            # Split on any whitespace so that newlines and repeated spaces in
            # a multi-line definition do not end up in the name or type.
            column_segments = column.split()
            if not column_segments:
                continue

            column_name = column_segments[0]
            if column_name[0] in ("'", '"', "`", "["):
                column_name = column_name[1:-1]

            if column_name in column_definitions:
                raise RuntimeError(f"Column: {column_name:s} already defined.")

            column_definition = resources.ColumnDefinition()
            column_definition.name = column_name

            # Note that a column definition can be defined without a type.
            if len(column_segments) > 1:
                column_definition.value_type = column_segments[1]

            column_definitions[column_name] = column_definition

        return column_definitions

    def _FormatSchemaAsYAML(self, schema):
        """Formats a schema into YAML.

        Virtual tables are skipped.

        Args:
          schema (dict[str, str]): schema as an SQL query per table name.

        Returns:
          str: schema formatted as YAML.

        Raises:
          RuntimeError: if a query could not be parsed.
        """
        lines = ["# SQLite-kb database schema."]

        for table_name, query in sorted(schema.items()):
            original_query = query

            # Replace \t by a space.
            query = re.sub(r"[\t]+", r" ", query, count=0)

            if not query.startswith("CREATE ") or query[-1] != ")":
                raise RuntimeError(f'Unsupported query: "{original_query:s}"')

            # Remove the leading "CREATE " and the closing ")".
            query = query[7:-1]
            if query.startswith("VIRTUAL "):
                continue

            if not query.startswith("TABLE "):
                raise RuntimeError(f'Unsupported query: "{original_query:s}"')

            query = query[6:]

            # The table name can be quoted; expect the same quoting style as
            # used by the query.
            opening_quote = query[:1]
            closing_quote = self._IDENTIFIER_QUOTES.get(opening_quote)
            if closing_quote:
                query_start = f"{opening_quote:s}{table_name:s}{closing_quote:s}"
            else:
                query_start = table_name

            if not query.startswith(query_start):
                raise RuntimeError(f'Unsupported query: "{original_query:s}"')

            # Note that there can be a space between table name and "(".
            query = query[len(query_start) :].lstrip()

            if not query.startswith("("):
                raise RuntimeError(f'Unsupported query: "{original_query:s}"')

            # Note that there can be a space between "(" and the column
            # definition.
            query = query[1:].lstrip()

            column_definitions = self._ParseColumnDefinitions(query)

            lines.extend(["---", f"table: {table_name:s}", "columns:"])

            for column_definition in column_definitions.values():
                lines.append(f"- name: {column_definition.name:s}")
                if column_definition.value_type:
                    # Two spaces align the key with "name" of the same item.
                    lines.append(
                        f"  value_type: {column_definition.value_type:s}"
                    )

        # Ensure the output ends with a newline.
        lines.append("")
        return "\n".join(lines)

    def _GetDatabaseIdentifier(self, path_segments):
        """Determines the database identifier.

        The trailing segments of every path of every known artifact definition
        are compared, case insensitive, against the trailing segments of the
        given path. The first definition with a matching path wins.

        Args:
          path_segments (list[str]): path segments.

        Returns:
          str: database identifier or None if the type could not be determined.
        """
        supported_type_indicators = (
            artifacts_definitions.TYPE_INDICATOR_DIRECTORY,
            artifacts_definitions.TYPE_INDICATOR_FILE,
            artifacts_definitions.TYPE_INDICATOR_PATH,
        )

        # TODO: make comparison more efficient.
        for (
            database_identifier,
            artifact_definition,
        ) in self._known_database_definitions.items():
            for source in artifact_definition.sources:
                if source.type_indicator not in supported_type_indicators:
                    continue

                for source_path in set(source.paths):
                    source_path_segments = source_path.split(source.separator)
                    if not source_path_segments[0]:
                        source_path_segments = source_path_segments[1:]

                    # TODO: add support for parameters.
                    # Keep only the segments after the last %parameter%
                    # segment. Note that when the parameter is the final
                    # segment the slice starts at 0 and the list is left
                    # whole.
                    last_index = len(source_path_segments)
                    for index in range(1, last_index + 1):
                        source_path_segment = source_path_segments[-index]
                        if not source_path_segment or len(source_path_segment) < 2:
                            continue

                        if (
                            source_path_segment[0] == "%"
                            and source_path_segment[-1] == "%"
                        ):
                            source_path_segments = source_path_segments[
                                -index + 1 :
                            ]
                            break

                    if len(source_path_segments) > len(path_segments):
                        continue

                    # Compare the segments back to front.
                    is_match = True
                    last_index = min(len(source_path_segments), len(path_segments))
                    for index in range(1, last_index + 1):
                        source_path_segment = source_path_segments[-index]
                        # TODO: improve handling of *
                        if "*" in source_path_segment:
                            continue

                        path_segment = path_segments[-index].lower()
                        source_path_segment = source_path_segment.lower()

                        is_match = path_segment == source_path_segment
                        if not is_match:
                            break

                    if is_match:
                        return database_identifier

        return None

    def _GetDatabaseSchema(self, path):
        """Retrieves schema from given SQLite 3 database.

        Args:
          path (str): path to SQLite 3 database file.

        Returns:
          dict[str, str]: schema as an SQL query per table name or None if
              the schema could not be retrieved.
        """
        try:
            database = sqlite3.connect(path)
        except sqlite3.DatabaseError as exception:
            logging.error(f"Unable to open database with error: {exception!s}")
            return None

        database.row_factory = sqlite3.Row

        schema = None
        try:
            cursor = database.cursor()
            rows = cursor.execute(self._SCHEMA_QUERY)
            schema = {row["tbl_name"]: row["sql"] for row in rows}

        except sqlite3.DatabaseError as exception:
            logging.error(f"Unable to query schema with error: {exception!s}")

        finally:
            database.close()

        # TODO: move schema into object.
        return schema

    def _GetDatabaseSchemaFromFileObject(self, file_object):
        """Retrieves schema from given SQLite 3 database file-like object.

        Args:
          file_object (dfvfs.FileIO): file-like object of the SQLite 3 database.

        Returns:
          dict[str, str]: schema as an SQL query per table name or None if
              the schema could not be retrieved.
        """
        # TODO: find an alternative solution that can read a SQLite database
        # directly from a file-like object.
        # NOTE: opening a named temporary file a second time by name, while it
        # is still open, is platform dependent.
        with tempfile.NamedTemporaryFile(delete=True) as temporary_file:
            file_object.seek(0, os.SEEK_SET)
            file_data = file_object.read(self._READ_BUFFER_SIZE)
            while file_data:
                temporary_file.write(file_data)
                file_data = file_object.read(self._READ_BUFFER_SIZE)

            # sqlite3 opens the file by name, hence buffered data must have
            # been written out before the database is queried.
            temporary_file.flush()

            return self._GetDatabaseSchema(temporary_file.name)

    def GetDisplayPath(self, path_segments, data_stream_name=None):
        """Retrieves a path to display.

        Args:
          path_segments (list[str]): path segments of the full path of the file
              entry.
          data_stream_name (Optional[str]): name of the data stream.

        Returns:
          str: path to display.
        """
        translation_table = (
            dfimagetools_definitions.NON_PRINTABLE_CHARACTER_TRANSLATION_TABLE
        )

        display_path = "/".join(
            segment.translate(translation_table) for segment in path_segments
        )

        if data_stream_name:
            data_stream_name = data_stream_name.translate(translation_table)
            display_path = ":".join([display_path, data_stream_name])

        return display_path or "/"

    def ExtractSchemas(self, path, options=None):
        """Extracts database schemas from the path.

        Args:
          path (str): path of a SQLite 3 database file or storage media image
              containing SQLite 3 database files.
          options (Optional[dfvfs.VolumeScannerOptions]): volume scanner
              options. If None the default volume scanner options are used,
              which are defined in the dfVFS VolumeScannerOptions class.

        Yields:
          tuple[str, dict[str, str]]: known database type identifier or the
              name of the SQLite database file if not known and schema.
        """
        entry_lister = file_entry_lister.FileEntryLister(mediator=self._mediator)

        base_path_specs = entry_lister.GetBasePathSpecs(path, options=options)
        if not base_path_specs:
            logging.warning(
                f"Unable to determine base path specifications from: {path:s}"
            )
            return

        for file_entry, path_segments in entry_lister.ListFileEntries(
            base_path_specs
        ):
            if not file_entry.IsFile() or file_entry.size < self._MINIMUM_FILE_SIZE:
                continue

            file_object = file_entry.GetFileObject()
            if not self._CheckSignature(file_object):
                continue

            display_path = self.GetDisplayPath(path_segments)

            database_schema = self._GetDatabaseSchemaFromFileObject(file_object)
            if database_schema is None:
                logging.warning(
                    (
                        f"Unable to determine schema from database file: "
                        f"{display_path:s}"
                    )
                )
                continue

            # TODO: improve support to determine identifier for single database
            # file.
            database_identifier = self._GetDatabaseIdentifier(path_segments)
            if not database_identifier:
                logging.warning(
                    (
                        f"Unable to determine known database identifier of file: "
                        f"{display_path:s}"
                    )
                )
                # Fall back to the name of the file.
                database_identifier = path_segments[-1]

            yield database_identifier, database_schema

    def FormatSchema(self, schema, output_format):
        """Formats a schema into the output format.

        Args:
          schema (dict[str, str]): schema as an SQL query per table name.
          output_format (str): output format, either "text" or "yaml".

        Returns:
          str: formatted schema.

        Raises:
          RuntimeError: if a query could not be parsed or the output format is
              not supported.
        """
        if output_format == "text":
            return self._FormatSchemaAsText(schema)

        if output_format == "yaml":
            return self._FormatSchemaAsYAML(schema)

        raise RuntimeError(f"Unsupported output format: {output_format:s}")