Tracking Information Flow

We have explored how one could generate better inputs that can penetrate deeper into the program in question. While doing so, we have relied on program crashes to tell us that we have succeeded in finding problems in the program. However, that is rather simplistic. What if the behavior of the program is simply incorrect, but does not lead to a crash? Can one do better?

In this chapter, we explore in depth how to track information flows in Python, and how these flows can be used to determine whether a program behaved as expected.

Prerequisites

We first set up our infrastructure so that we can make use of previously defined functions.

Synopsis

To use the code provided in this chapter, write

>>> from fuzzingbook.InformationFlow import <identifier>

and then make use of the following features.

This chapter provides two wrappers around Python strings that allow one to track various properties. These include information on the security properties of the input, and information on the originating indexes of characters in the input string.

For tracking information on security properties, use tstr as follows:

>>> thello = tstr('hello', taint='LOW')

Now, any operation on thello that results in a string fragment will carry the correct taint. For example:

>>> thello[1:2].taint
'LOW'

For tracking the originating indexes from the input string, use ostr as follows:

>>> ohw = ostr("hello\tworld", origin=100)

The originating indexes can be recovered as follows:

>>> (ohw[0:4] +"-"+ ohw[6:]).origin
[100, 101, 102, 103, -1, 106, 107, 108, 109, 110]

A Vulnerable Database

Say we want to implement an in-memory database service in Python. Here is a rather flimsy attempt. We use the following dataset.

INVENTORY = """\
1997,van,Ford,E350
2000,car,Mercury,Cougar
1999,car,Chevy,Venture\
"""
VEHICLES = INVENTORY.split('\n')

Our DB is a Python class that parses its arguments and raises an SQLException on errors; this exception class is defined below.

class SQLException(Exception):
    pass

The database is simply a Python dict that is exposed only through SQL queries.

class DB:
    def __init__(self, db={}):
        self.db = dict(db)

Representing Tables

The database contains tables, which are created by calling create_table(). Each table is represented as a pair of values. The first one is the metadata, containing column names and types. The second value is a list of the rows in the table.

class DB(DB):
    def create_table(self, table, defs):
        self.db[table] = (defs, [])

A table can be retrieved by name using the table() method.

class DB(DB):
    def table(self, t_name):
        if t_name in self.db:
            return self.db[t_name]
        raise SQLException('Table (%s) was not found' % repr(t_name))

Here is an example of how to use both. We create a table inventory with four columns: year, kind, company, and model. Initially, our table is empty.

def sample_db():
    db = DB()
    inventory_def = {'year': int, 'kind': str, 'company': str, 'model': str}
    db.create_table('inventory', inventory_def)
    return db

Using table(), we can retrieve the table definition as well as its contents.

db = sample_db()
db.table('inventory')
({'year': int, 'kind': str, 'company': str, 'model': str}, [])

We also define column() for retrieving the column definition from a table declaration.

class DB(DB):
    def column(self, table_decl, c_name):
        if c_name in table_decl: 
            return table_decl[c_name]
        raise SQLException('Column (%s) was not found' % repr(c_name))
db = sample_db()
decl, rows = db.table('inventory')
db.column(decl, 'year')
int

Executing SQL Statements

The sql() method of DB executes SQL statements. It inspects its arguments, and dispatches the query based on the kind of SQL statement to be executed.

class DB(DB):
    def do_select(self, query):
        assert False
    def do_update(self, query):
        assert False
    def do_insert(self, query):
        assert False
    def do_delete(self, query):
        assert False

    def sql(self, query):
        methods = [('select ', self.do_select), 
                   ('update ', self.do_update),
                   ('insert into ', self.do_insert),
                   ('delete from ', self.do_delete)]
        for key, method in methods:
            if query.startswith(key):
                return method(query[len(key):])
        raise SQLException('Unknown SQL (%s)' % query)

At this point, the individual methods for handling SQL statements are not yet defined. Let us do this in the next steps.

Selecting Data

The do_select() method handles SQL select statements to retrieve data from a table.

class DB(DB):
    def do_select(self, query):
        FROM, WHERE = ' from ', ' where '
        table_start = query.find(FROM)
        if table_start < 0:
            raise SQLException('no table specified')

        where_start = query.find(WHERE)
        select = query[:table_start]

        if where_start >= 0:
            t_name = query[table_start + len(FROM):where_start]
            where = query[where_start + len(WHERE):]
        else:
            t_name = query[table_start + len(FROM):]
            where = ''
        _, table = self.table(t_name)

        if where:
            selected = self.expression_clause(table, "(%s)" % where)
            selected_rows = [hm for i, data, hm in selected if data]
        else:
            selected_rows = table

        rows = self.expression_clause(selected_rows, "(%s)" % select)
        return [data for i, data, hm in rows]

The expression_clause() method is used for two purposes:

  1. In the form select $x$, $y$, $z$ from $t$, it evaluates (and returns) the expressions $x$, $y$, $z$ in the context of each selected row.
  2. If a clause where $p$ is given, it also evaluates $p$ in the context of each row and includes a row in the selection only if $p$ holds.

To evaluate expressions like $x$, $y$, $z$ or $p$, we use Python's built-in eval() function, with the current row (a dict from column names to values) as the local namespace.

class DB(DB):
    def expression_clause(self, table, statement):
        selected = []
        for i, hm in enumerate(table):
            selected.append((i, self.my_eval(statement, {}, hm), hm))

        return selected

Internally, expression_clause() calls my_eval() to evaluate each statement.

class DB(DB):
    def my_eval(self, statement, g, l):
        try:
            return eval(statement, g, l)
        except:
            raise SQLException('Invalid WHERE (%s)' % repr(statement))

Note: Using eval() here introduces some important security issues, which we will discuss later in this chapter.
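To see what happens for each row, here is a standalone sketch (not part of the DB class) that calls eval() the way my_eval() does: an empty dict as globals and one row dict as locals. Note that the empty globals dict does not restrict anything: Python inserts a reference to the builtins under the key __builtins__, so functions such as len() remain available.

```python
# One table row, represented as a dict as in the DB class above
row = {'year': 1997, 'kind': 'van', 'company': 'Ford', 'model': 'E350'}

# A select list "year + 1, kind" is wrapped in parentheses, yielding a tuple
select_result = eval("(year + 1, kind)", {}, row)

# A single selected expression yields a plain value, not a 1-tuple
single_result = eval("(year)", {}, row)

# A where clause yields a truth value for this row
where_result = eval("(year == 1997)", {}, row)

# Builtins are reachable even though globals is an empty dict
builtin_result = eval("(len(model))", {}, row)

print(select_result)   # (1998, 'van')
print(single_result)   # 1997
print(where_result)    # True
print(builtin_result)  # 4
```

This is why selecting a single column yields a list of plain values, while selecting several columns yields a list of tuples.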

Here's how we can use sql() to issue a query. Note that the table is still empty.

db = sample_db()
db.sql('select year from inventory')
[]
db = sample_db()
db.sql('select year from inventory where year == 2018')
[]

Inserting Data

The do_insert() method handles SQL insert statements.

class DB(DB):
    def do_insert(self, query):
        VALUES = ' values '
        table_end = query.find('(')
        t_name = query[:table_end].strip()
        names_end = query.find(')')
        decls, table = self.table(t_name)
        names = [i.strip() for i in query[table_end + 1:names_end].split(',')]

        # verify columns exist
        for k in names:
            self.column(decls, k)

        values_start = query.find(VALUES)

        if values_start < 0:
            raise SQLException('Invalid INSERT (%s)' % repr(query))

        values = [
            i.strip() for i in query[values_start + len(VALUES) + 1:-1].split(',')
        ]

        if len(names) != len(values):
            raise SQLException(
                'names(%s) != values(%s)' % (repr(names), repr(values)))

        kvs = {k: self.convert(decls[k], v) for k, v in zip(names, values)}
        table.append(kvs)

In SQL, a column can have any supported data type. To ensure that each value is stored with the type originally declared for its column, we need the ability to convert values to specific types, which is provided by convert().

import ast
class DB(DB):
    def convert(self, cast, value):
        try:
            return cast(ast.literal_eval(value))
        except:
            raise SQLException('Invalid Conversion %s(%s)' % (cast, value))
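The following standalone sketch mirrors the core of convert() without the SQLException wrapper (the names convert_value() and is_rejected() are hypothetical). It illustrates why ast.literal_eval() is used here rather than eval(): it accepts only literal structures such as numbers, strings, tuples, lists, and dicts, and raises ValueError for names, operators on names, or function calls.

```python
import ast

def convert_value(cast, value):
    # Parse the text as a Python literal, then coerce it to the column type
    return cast(ast.literal_eval(value))

def is_rejected(value):
    # literal_eval() raises ValueError for anything that is not a plain literal
    try:
        ast.literal_eval(value)
    except ValueError:
        return True
    return False

print(convert_value(int, '1997'))        # 1997
print(convert_value(str, '"van"'))       # van
print(is_rejected('__import__("os")'))   # True
print(is_rejected('year + 1'))           # True
```

Values in insert and update statements are therefore safe from code injection; the where clauses, however, still go through eval().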

Here is an example of how to use the SQL insert command:

db = sample_db()
db.sql('insert into inventory (year, kind, company, model) values (1997, "van", "Ford", "E350")')
db.table('inventory')
({'year': int, 'kind': str, 'company': str, 'model': str},
 [{'year': 1997, 'kind': 'van', 'company': 'Ford', 'model': 'E350'}])

With the database filled, we can also run more complex queries:

db.sql('select year + 1, kind from inventory')
[(1998, 'van')]
db.sql('select year, kind from inventory where year == 1997')
[(1997, 'van')]

Updating Data

Similarly, do_update() handles SQL update statements.

class DB(DB):
    def do_update(self, query):
        SET, WHERE = ' set ', ' where '
        table_end = query.find(SET)

        if table_end < 0:
            raise SQLException('Invalid UPDATE (%s)' % repr(query))

        set_end = table_end + len(SET)
        t_name = query[:table_end]
        decls, table = self.table(t_name)
        names_end = query.find(WHERE)

        if names_end >= 0:
            names = query[set_end:names_end]
            where = query[names_end + len(WHERE):]
        else:
            names = query[set_end:]
            where = ''

        sets = [[i.strip() for i in name.split('=')]
                for name in names.split(',')]

        # verify columns exist
        for k, v in sets:
            self.column(decls, k)

        if where:
            selected = self.expression_clause(table, "(%s)" % where)
            updated = [hm for i, d, hm in selected if d]
        else:
            updated = table

        for hm in updated:
            for k, v in sets:
                hm[k] = self.convert(decls[k], v)

        return "%d records were updated" % len(updated)
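The parsing of the set clause in do_update() can be tried in isolation. The helper parse_assignments() below is hypothetical; it simply repeats the list comprehension from do_update() so that its behavior on a few inputs becomes visible.

```python
def parse_assignments(names):
    # Same splitting as in do_update(): first on ',', then on '='
    return [[part.strip() for part in assignment.split('=')]
            for assignment in names.split(',')]

print(parse_assignments("year = 1998, company = 'Suzuki'"))
# [['year', '1998'], ['company', "'Suzuki'"]]

print(parse_assignments("model = 'a,b'"))
# [['model', "'a"], ["b'"]]
```

Because the clause is split naively on commas, a string value containing a comma is torn apart. In do_update(), the subsequent `for k, v in sets` unpacking would then fail with a ValueError rather than an SQLException.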

Here is an example. Let us first fill the database again with values:

db = sample_db()
db.sql('insert into inventory (year, kind, company, model) values (1997, "van", "Ford", "E350")')
db.sql('select year from inventory')
[1997]

Now we can update things:

db.sql('update inventory set year = 1998 where year == 1997')
db.sql('select year from inventory')
[1998]
db.table('inventory')
({'year': int, 'kind': str, 'company': str, 'model': str},
 [{'year': 1998, 'kind': 'van', 'company': 'Ford', 'model': 'E350'}])

Deleting Data

Finally, SQL delete statements are handled by do_delete().

class DB(DB):
    def do_delete(self, query):
        WHERE = ' where '
        table_end = query.find(WHERE)
        if table_end < 0:
            raise SQLException('Invalid DELETE (%s)' % query)
        t_name = query[:table_end].strip()
        _, table = self.table(t_name)
        where = query[table_end + len(WHERE):]
        selected = self.expression_clause(table, "%s" % where)
        deleted = [i for i, d, hm in selected if d]
        for i in sorted(deleted, reverse=True):
            del table[i]
        return "%d records were deleted" % len(deleted)
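do_delete() removes the matching rows by index, in descending order. A small standalone sketch (both helper names are hypothetical) shows why this order matters.

```python
def delete_indices(rows, indices):
    # Delete from the highest index down, so that removing one row
    # does not shift the positions of rows still to be deleted
    for i in sorted(indices, reverse=True):
        del rows[i]
    return rows

def delete_indices_naive(rows, indices):
    # Deleting in ascending order: each deletion shifts later rows left
    for i in sorted(indices):
        del rows[i]
    return rows

print(delete_indices(['a', 'b', 'c', 'd'], [1, 2]))        # ['a', 'd']
print(delete_indices_naive(['a', 'b', 'c', 'd'], [1, 2]))  # ['a', 'c']
```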

Here is an example. Let us first fill the database again with values:

db = sample_db()
db.sql('insert into inventory (year, kind, company, model) values (1997, "van", "Ford", "E350")')
db.sql('select year from inventory')
[1997]

Now we can delete data:

db.sql('delete from inventory where company == "Ford"')
'1 records were deleted'

Our database is now empty:

db.sql('select year from inventory')
[]

All Methods Together

Here is how our database can be used.

db = DB()

Again, we first create a table in our database with the correct data types.

inventory_def = {'year': int, 'kind': str, 'company': str, 'model': str}
db.create_table('inventory', inventory_def)

Here is a simple convenience function to insert a row from our dataset into the table.

def update_inventory(sqldb, vehicle):
    inventory_def = sqldb.db['inventory'][0]
    # Column names and their declared types, in declaration order
    k, v = zip(*inventory_def.items())
    # Cast each CSV field to its column type, then render it as a literal
    val = [repr(cast(field)) for cast, field in zip(v, vehicle.split(','))]
    sqldb.sql('insert into inventory (%s) values (%s)' % (','.join(k),
                                                          ','.join(val)))

for V in VEHICLES:
    update_inventory(db, V)

Our database now contains the same dataset as VEHICLES in its inventory table.

db.db
{'inventory': ({'year': int, 'kind': str, 'company': str, 'model': str},
  [{'year': 1997, 'kind': 'van', 'company': 'Ford', 'model': 'E350'},
   {'year': 2000, 'kind': 'car', 'company': 'Mercury', 'model': 'Cougar'},
   {'year': 1999, 'kind': 'car', 'company': 'Chevy', 'model': 'Venture'}])}

Here are some sample select statements.

db.sql('select year,kind from inventory')
[(1997, 'van'), (2000, 'car'), (1999, 'car')]
db.sql("select company,model from inventory where kind == 'car'")
[('Mercury', 'Cougar'), ('Chevy', 'Venture')]

We can run updates on it.

db.sql("update inventory set year = 1998, company = 'Suzuki' where kind == 'van'")
'1 records were updated'
db.db
{'inventory': ({'year': int, 'kind': str, 'company': str, 'model': str},
  [{'year': 1998, 'kind': 'van', 'company': 'Suzuki', 'model': 'E350'},
   {'year': 2000, 'kind': 'car', 'company': 'Mercury', 'model': 'Cougar'},
   {'year': 1999, 'kind': 'car', 'company': 'Chevy', 'model': 'Venture'}])}

It can even do mathematics on the fly!

db.sql('select int(year)+10 from inventory')
[2008, 2010, 2009]

We can also add a new row to our table.

db.sql("insert into inventory (year, kind, company, model) values (1, 'charriot', 'Rome', 'Quadriga')")
db.db
{'inventory': ({'year': int, 'kind': str, 'company': str, 'model': str},
  [{'year': 1998, 'kind': 'van', 'company': 'Suzuki', 'model': 'E350'},
   {'year': 2000, 'kind': 'car', 'company': 'Mercury', 'model': 'Cougar'},
   {'year': 1999, 'kind': 'car', 'company': 'Chevy', 'model': 'Venture'},
   {'year': 1, 'kind': 'charriot', 'company': 'Rome', 'model': 'Quadriga'}])}

We then delete this row again:

db.sql("delete from inventory where year < 1900")
'1 records were deleted'

Fuzzing SQL

To verify that everything is OK, let us fuzz. First we define our grammar.

import string
EXPR_GRAMMAR = {
    "<start>": ["<expr>"],
    "<expr>": ["<bexpr>", "<aexpr>", "(<expr>)", "<term>"],
    "<bexpr>": [
        "<aexpr><lt><aexpr>",
        "<aexpr><gt><aexpr>",
        "<expr>==<expr>",
        "<expr>!=<expr>",
    ],
    "<aexpr>": [
        "<aexpr>+<aexpr>", "<aexpr>-<aexpr>", "<aexpr>*<aexpr>",
        "<aexpr>/<aexpr>", "<word>(<exprs>)", "<expr>"
    ],
    "<exprs>": ["<expr>,<exprs>", "<expr>"],
    "<lt>": ["<"],
    "<gt>": [">"],
    "<term>": ["<number>", "<word>"],
    "<number>": ["<integer>.<integer>", "<integer>", "-<number>"],
    "<integer>": ["<digit><integer>", "<digit>"],
    "<word>": ["<word><letter>", "<word><digit>", "<letter>"],
    "<digit>":
    list(string.digits),
    "<letter>":
    list(string.ascii_letters + '_:.')
}
INVENTORY_GRAMMAR = dict(
    EXPR_GRAMMAR, **{
        '<start>': ['<query>'],
        '<query>': [
            'select <exprs> from <table>',
            'select <exprs> from <table> where <bexpr>',
            'insert into <table> (<names>) values (<literals>)',
            'update <table> set <assignments> where <bexpr>',
            'delete from <table> where <bexpr>',
        ],
        '<table>': ['<word>'],
        '<names>': ['<column>,<names>', '<column>'],
        '<column>': ['<word>'],
        '<literals>': ['<literal>', '<literal>,<literals>'],
        '<literal>': ['<number>', "'<chars>'"],
        '<assignments>': ['<kvp>,<assignments>', '<kvp>'],
        '<kvp>': ['<column>=<value>'],
        '<value>': ['<word>'],
        '<chars>': ['<char>', '<char><chars>'],
        '<char>':
        [i for i in string.printable if i not in "<>'\"\t\n\r\x0b\x0c\x00"
         ] + ['<lt>', '<gt>'],
    })

As the source of our database shows, every statement first checks whether the given table exists. A fuzzer that picks random words as table names will thus rarely get past this check. Hence, we modify the grammar to always choose our inventory table, giving the fuzzer a better chance of reaching deeper into the code. We will see in later sections how this can be done automatically.

import traceback
from GrammarFuzzer import GrammarFuzzer

INVENTORY_GRAMMAR_F = dict(INVENTORY_GRAMMAR, **{'<table>': ['inventory']})
gf = GrammarFuzzer(INVENTORY_GRAMMAR_F)
for _ in range(10):
    query = gf.fuzz()
    print(repr(query))
    try:
        res = db.sql(query)
        print(repr(res))
    except SQLException as e:
        # Expected: the database rejected the query
        print("> ", e)
    except:
        # Anything else is an unexpected crash
        traceback.print_exc()
        break
    print()
'select O6fo,-977091.1,-36.46 from inventory'
>  Invalid WHERE ('(O6fo,-977091.1,-36.46)')

'select g3 from inventory where -3.0!=V/g/b+Q*M*G'
>  Invalid WHERE ('(-3.0!=V/g/b+Q*M*G)')

'update inventory set z=a,x=F_,Q=K where p(M)<_*S'
>  Column ('z') was not found

'update inventory set R=L5pk where e*l*y-u>K+U(:)'
>  Column ('R') was not found

'select _/d*Q+H/d(k)<t+M-A+P from inventory'
>  Invalid WHERE ('(_/d*Q+H/d(k)<t+M-A+P)')

'select F5 from inventory'
>  Invalid WHERE ('(F5)')

'update inventory set jWh.=a6 where wcY(M)>IB7(i)'
>  Column ('jWh.') was not found

'update inventory set U=y where L(W<c,(U!=W))<V(((q)==m<F),O,l)'
>  Column ('U') was not found

'delete from inventory where M/b-O*h*E<H-W>e(Y)-P'
>  Invalid WHERE ('M/b-O*h*E<H-W>e(Y)-P')

'select ((kP(86)+b*S+J/Z/U+i(U))) from inventory'
>  Invalid WHERE ('(((kP(86)+b*S+J/Z/U+i(U))))')

Fuzzing does not seem to have triggered any crashes. However, are crashes the only errors that we should be worried about?

The Evil of Eval

In our implementation, we have made use of eval() to evaluate expressions using the Python interpreter. This allows us to unleash the full power of Python expressions within our SQL statements.

db.sql('select year from inventory where year < 2000')
[1998, 1999]

In the above query, the clause year < 2000 is evaluated by expression_clause() as a Python expression in the context of each row; hence, year < 2000 evaluates to either True or False.

The same holds for the expressions being selected:

db.sql('select year - 1900 if year < 2000 else year - 2000 from inventory')
[98, 0, 99]

This works because year - 1900 if year < 2000 else year - 2000 is a valid Python expression. (It is not a valid SQL expression, though.)

The problem with the above is that there is no limitation to what the Python expression can do. What if the user tries the following?

db.sql('select __import__("os").popen("pwd").read() from inventory')
['/Users/zeller/Projects/fuzzingbook/notebooks\n',
 '/Users/zeller/Projects/fuzzingbook/notebooks\n',
 '/Users/zeller/Projects/fuzzingbook/notebooks\n']

The above statement effectively gives the user access to the shell and file system of the machine running the database: it runs the shell command pwd and returns its output. Instead of os.popen("pwd").read(), it could execute arbitrary Python code: accessing data, installing software, or running a background process. This is where "the full power of Python expressions" turns back on us.

What we want is to allow our program to make full use of its power; yet, the user (or any third party) should not be entrusted to do the same. Hence, we need to differentiate between (trusted) input from the program and (untrusted) input from the user.

One method that allows such differentiation is dynamic taint analysis. We identify the functions that accept user input as sources, which taint any string that comes in through them, and the functions that perform dangerous operations as sinks. Finally, we bless certain functions as taint sanitizers. The rule is that an input from a source should never reach a sink without undergoing sanitization first. This gives us a stronger oracle than simply checking for crashes.
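Before implementing this for strings, here is a deliberately simplified sketch of the source/sanitizer/sink roles. Everything in it is hypothetical: the Labeled carrier, the toy_* function names, and the whitelist rule (letters, digits, and spaces only) are our own assumptions for illustration.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Labeled:
    """Hypothetical minimal carrier: a value plus a trust flag."""
    value: str
    trusted: bool

def toy_source(raw):
    # Anything that enters through a source starts out untrusted
    return Labeled(raw, trusted=False)

def toy_sanitize(data):
    # Hypothetical whitelist: letters, digits and spaces only
    if data.value.replace(' ', '').isalnum():
        return Labeled(data.value, trusted=True)
    return Labeled('', trusted=False)

def toy_sink(data):
    # The dangerous operation refuses anything that was not sanitized
    if not data.trusted:
        raise PermissionError('untrusted data reached the sink')
    return 'executed: ' + data.value

def reaches_sink(data):
    # True if the sink accepts the data, False if it refuses it
    try:
        toy_sink(data)
        return True
    except PermissionError:
        return False

print(toy_sink(toy_sanitize(toy_source('select year from inventory'))))
print(reaches_sink(toy_source('select year from inventory')))  # False
```

The key point is that trust is a property carried along with the data, so the sink can check it without knowing where the data came from.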

Tracking String Taints

There are various levels of taint tracking that one can perform. The simplest is to track that a string fragment originated in a specific environment, and has not undergone a taint removal process. For this, we simply need to wrap the original string with an environment identifier (the taint) with tstr, and produce tstr instances on each operation that results in another string fragment. The attribute taint holds a label identifying the environment from which this instance was derived.

A Class for Tainted Strings

For capturing information flows we need a new string class. The idea is to use the new tainted string class tstr as a wrapper on the original str class. However, str is an immutable class: its value is fixed when the object is created in __new__(), not in __init__(). Moreover, str.__new__() does not accept additional arguments such as our taint. Hence, a subclass of str cannot simply take extra constructor arguments in __init__(); we also need to hook into __new__(), pass only the string value on to str.__new__(), and return an instance of our own class. Python then calls __init__() on this instance, where we store the taint.

class tstr(str):
    def __new__(cls, value, *args, **kw):
        return str.__new__(cls, value)

    def __init__(self, value, taint=None, **kwargs):
        self.taint = taint
class tstr(tstr):
    def __repr__(self):
        return tstr(str.__repr__(self), taint=self.taint)
class tstr(tstr):
    def __str__(self):
        return str.__str__(self)
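The following standalone sketch illustrates the point about __new__() and __init__(). Both class names are hypothetical; TaintedStr follows the same pattern as tstr above.

```python
class NaiveTaintedStr(str):
    # Only __init__() is overridden; str.__new__() still receives `taint`
    def __init__(self, value, taint=None):
        self.taint = taint

class TaintedStr(str):
    # __new__() passes only the value on to str.__new__()
    def __new__(cls, value, *args, **kwargs):
        return str.__new__(cls, value)

    # __init__() is still called afterwards and stores the taint
    def __init__(self, value, taint=None):
        self.taint = taint

def construction_error(cls):
    """Return the TypeError message raised by cls('hello', taint='LOW'), or None."""
    try:
        cls('hello', taint='LOW')
    except TypeError as e:
        return str(e)
    return None

print(construction_error(NaiveTaintedStr))     # an error message mentioning 'taint'
print(construction_error(TaintedStr))          # None
print(NaiveTaintedStr('hello').taint)          # None: __init__() did run
print(TaintedStr('hello', taint='LOW').taint)  # LOW
```

The third line shows that __init__() is indeed called for str subclasses; the problem is only that str.__new__() rejects the extra argument.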

For example, if we wrap "hello" in tstr, then we should be able to access its taint:

thello = tstr('hello', taint='LOW')
thello.taint
'LOW'
repr(thello).taint
'LOW'

Once a string is wrapped, we also need a way to clear its taint. One way is to convert it back into a plain str instance, for instance with str(), using the __str__() method defined above. Sometimes, however, one may wish to remove the taint from an existing tstr instance. This is accomplished with clear_taint(), which simply sets the taint to None. It comes with a counterpart has_taint(), which checks whether a tstr instance currently carries a taint.

class tstr(tstr):
    def clear_taint(self):
        self.taint = None
        return self

    def has_taint(self):
        return self.taint is not None

String Operators

To propagate the taint, we have to extend string functions, such as operators. We can do so in one single big step, overloading all string methods and operators.

When we create a new string from an existing tainted string, we propagate its taint.

class tstr(tstr):
    def create(self, s):
        return tstr(s, taint=self.taint)

The make_str_wrapper() function creates a wrapper around an existing string method which attaches the taint to the result of the method:

def make_str_wrapper(fun):
    def proxy(self, *args, **kwargs):
        res = fun(self, *args, **kwargs)
        return self.create(res)
    return proxy

We do this for all string methods that return a string:

# encode() is deliberately not listed: it returns bytes, not a string
for name in ['__format__', '__mod__', '__rmod__', '__getitem__', '__add__', '__mul__', '__rmul__',
             'capitalize', 'casefold', 'center',
             'expandtabs', 'format', 'format_map', 'join', 'ljust', 'lower', 'lstrip', 'replace',
             'rjust', 'rstrip', 'strip', 'swapcase', 'title', 'translate', 'upper']:
    fun = getattr(str, name)
    setattr(tstr, name, make_str_wrapper(fun))
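To see the wrapping technique in isolation, here is a self-contained sketch using a hypothetical stand-in class LabeledStr and a hypothetical wrapper factory make_labeled_wrapper(), wrapping only three methods. It also shows why only methods that return strings should be wrapped: a method such as split() returns a list of plain str objects, which carry no taint.

```python
class LabeledStr(str):
    """Hypothetical stand-in for tstr: a str carrying a `taint` label."""
    def __new__(cls, value, taint=None):
        return str.__new__(cls, value)

    def __init__(self, value, taint=None):
        self.taint = taint

    def create(self, s):
        # Propagate our taint to a newly derived string
        return LabeledStr(s, taint=self.taint)

def make_labeled_wrapper(fun):
    def proxy(self, *args, **kwargs):
        return self.create(fun(self, *args, **kwargs))
    return proxy

for name in ['__getitem__', 'upper', 'strip']:
    setattr(LabeledStr, name, make_labeled_wrapper(getattr(str, name)))

s = LabeledStr('  hello  ', taint='LOW')
print(s.strip().upper(), s.strip().upper().taint)  # HELLO LOW
print(s[2:4].taint)                                # LOW
print(type(s.split()[0]).__name__)                 # str: split() is not wrapped
```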

The one missing operator is + with a regular string on the left side and a tainted string on the right side. Python supports a __radd__() method which is invoked if the associated object is used on the right side of an addition.

class tstr(tstr):
    def __radd__(self, s):
        return self.create(s + str(self))
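Why does __radd__() get a chance at all, given that str itself knows how to concatenate? Python's operator dispatch has a special rule: if the right operand's type is a subclass of the left operand's type and provides a different reflected method, that reflected method is tried first. The hypothetical class Tagged below merely reports which method Python dispatched to.

```python
class Tagged(str):
    # Each method just reports that it was the one Python called
    def __add__(self, other):
        return 'Tagged.__add__'

    def __radd__(self, other):
        return 'Tagged.__radd__'

print('foo' + Tagged('bar'))          # Tagged.__radd__
print(Tagged('foo') + 'bar')          # Tagged.__add__
print(Tagged('foo') + Tagged('bar'))  # Tagged.__add__
```

When both operands have the same type, only the left operand's __add__() is used.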

With this, we are already done. Let us create a string thello with a taint LOW.

thello = tstr('hello', taint='LOW')

Now, any substring will also be tainted:

thello[0].taint
'LOW'
thello[1:3].taint
'LOW'

String additions will return a tstr object with the taint:

(tstr('foo', taint='HIGH') + 'bar').taint
'HIGH'

Our __radd__() method ensures this also works if the tstr occurs on the right side of a string addition:

('foo' + tstr('bar', taint='HIGH')).taint
'HIGH'
thello += ', world'
thello.taint
'LOW'

Other operators such as multiplication also work:

(thello * 5).taint
'LOW'
('hw %s' % thello).taint
'LOW'
(tstr('hello %s', taint='HIGH') % 'world').taint
'HIGH'
import string

Tracking Untrusted Input

So, what can one do with tainted strings? We reconsider the DB example. We define a "better" TrustedDB which only accepts strings tainted as "TRUSTED".

class TrustedDB(DB):
    def sql(self, s):
        assert isinstance(s, tstr), "Need a tainted string"
        assert s.taint == 'TRUSTED', "Need a string with trusted taint"
        return super().sql(s)

Feeding a string with an "unknown" (i.e., non-existing) trust level will cause TrustedDB to fail:

bdb = TrustedDB(db.db)
from ExpectError import ExpectError
with ExpectError():
    bdb.sql("select year from INVENTORY")
Traceback (most recent call last):
  File "<ipython-input-79-65a521f9999f>", line 2, in <module>
    bdb.sql("select year from INVENTORY")
  File "<ipython-input-76-53a654b6cc10>", line 3, in sql
    assert isinstance(s, tstr), "Need a tainted string"
AssertionError: Need a tainted string (expected)

Additionally, any user input would initially be tagged with "UNTRUSTED" as taint. If we pass an untrusted string to our better database, it will also fail:

bad_user_input = tstr('__import__("os").popen("ls").read()', taint='UNTRUSTED')
with ExpectError():
    bdb.sql(bad_user_input)
Traceback (most recent call last):
  File "<ipython-input-80-82c5b2d628ed>", line 3, in <module>
    bdb.sql(bad_user_input)
  File "<ipython-input-76-53a654b6cc10>", line 4, in sql
    assert s.taint == 'TRUSTED', "Need a string with trusted taint"
AssertionError: Need a string with trusted taint (expected)

Hence, somewhere along the computation, we have to turn the "untrusted" inputs into "trusted" strings. This process is called sanitization. A simple sanitization function for our purposes could ensure that the input has the form select ... from ... and consists only of a few allowed characters (letters, digits, underscores, minus signs, commas, spaces, and parentheses; in particular, no quotes, dots, or comparison operators). If this is the case, then the input gets a new "TRUSTED" taint. If not, we turn the string into an (untrusted) empty string; other alternatives would be to raise an error or to escape or delete "untrusted" characters.

import re
def sanitize(user_input):
    assert isinstance(user_input, tstr)
    if re.match(
            r'^select +[-a-zA-Z0-9_, ()]+ from +[-a-zA-Z0-9_, ()]+$', user_input):
        return tstr(user_input, taint='TRUSTED')
    else:
        return tstr('', taint='UNTRUSTED')
good_user_input = tstr("select year,model from inventory", taint='UNTRUSTED')
sanitized_input = sanitize(good_user_input)
sanitized_input
'select year,model from inventory'
sanitized_input.taint
'TRUSTED'
bdb.sql(sanitized_input)
[(1998, 'E350'), (2000, 'Cougar'), (1999, 'Venture')]

Let us now try out our untrusted input:

sanitized_input = sanitize(bad_user_input)
sanitized_input
''
sanitized_input.taint
'UNTRUSTED'
with ExpectError():
    bdb.sql(sanitized_input)
Traceback (most recent call last):
  File "<ipython-input-88-e59f9e5c9d30>", line 2, in <module>
    bdb.sql(sanitized_input)
  File "<ipython-input-76-53a654b6cc10>", line 4, in sql
    assert s.taint == 'TRUSTED', "Need a string with trusted taint"
AssertionError: Need a string with trusted taint (expected)

In a similar fashion, we can prevent SQL and code injections discussed in the chapter on Web fuzzing.
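The whitelist pattern from sanitize() can be explored on plain strings. The following sketch reuses the same regular expression; the helper is_allowed() is hypothetical.

```python
import re

# The same whitelist pattern as used in sanitize() above
SELECT_ONLY = r'^select +[-a-zA-Z0-9_, ()]+ from +[-a-zA-Z0-9_, ()]+$'

def is_allowed(query):
    return re.match(SELECT_ONLY, query) is not None

print(is_allowed('select year,model from inventory'))              # True
print(is_allowed('__import__("os").popen("ls").read()'))           # False
print(is_allowed('select year from inventory where year < 2000'))  # False
print(is_allowed('select __import__(os) from inventory'))          # True
```

The last example shows that the whitelist still admits function calls. This particular query would only fail in eval() because os is not a defined name, but it illustrates that such patterns need careful review.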

Taint Aware Fuzzing

We can also use tainting to direct fuzzing to those grammar rules that are likely to generate dangerous inputs. The idea here is to identify inputs generated by our fuzzer that lead to untrusted execution. First we define the exception to be thrown when a tainted value reaches a dangerous operation.

class Tainted(Exception):
    def __init__(self, v):
        self.v = v

    def __str__(self):
        return 'Tainted[%s]' % self.v

TaintedDB

Next, since my_eval() is the most dangerous operation in the DB class, we define a new class TaintedDB that overrides my_eval() to raise an exception whenever an untrusted string reaches this point.

class TaintedDB(DB):
    def my_eval(self, statement, g, l):
        if statement.taint != 'TRUSTED':
            raise Tainted(statement)
        try:
            return eval(statement, g, l)
        except:
            raise SQLException('Invalid SQL (%s)' % repr(statement))

We initialize an instance of TaintedDB that shares the contents of our db:

tdb = TaintedDB()
tdb.db = db.db

Then we start fuzzing.

import traceback
for _ in range(10):
    query = gf.fuzz()
    print(repr(query))
    try:
        res = tdb.sql(tstr(query, taint='UNTRUSTED'))
        print(repr(res))
    except SQLException as e:
        pass
    except Tainted as e:
        print("> ", e)
    except:
        traceback.print_exc()
        break
    print()
'delete from inventory where y/u-l+f/y<Y(c)/A-H*q'
>  Tainted[y/u-l+f/y<Y(c)/A-H*q]

"insert into inventory (G,Wmp,sl3hku3) values ('<','?')"

"insert into inventory (d0) values (',_G')"

'select P*Q-w/x from inventory where X<j==:==j*r-f'
>  Tainted[(X<j==:==j*r-f)]

'select a>F*i from inventory where Q/I-_+P*j>.'
>  Tainted[(Q/I-_+P*j>.)]

'select (V-i<T/g) from inventory where T/r/G<FK(m)/(i)'
>  Tainted[(T/r/G<FK(m)/(i))]

'select (((i))),_(S,_)/L-k<H(Sv,R,n,W,Y) from inventory'
>  Tainted[((((i))),_(S,_)/L-k<H(Sv,R,n,W,Y))]

'select (N==c*U/P/y),i-e/n*y,T!=w,u from inventory'
>  Tainted[((N==c*U/P/y),i-e/n*y,T!=w,u)]

'update inventory set _=B,n=v where o-p*k-J>T'

'select s from inventory where w4g4<.m(_)/_>t'
>  Tainted[(w4g4<.m(_)/_>t)]

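One can see that select and delete statements on an existing table lead to taint exceptions, since their expressions reach my_eval(). insert statements never do, as they convert their values with ast.literal_eval() instead; and the update statement in this run fails the column check before its where clause is evaluated. We can now focus on these specific kinds of inputs. However, this is not the only thing we can do. We will see how we can identify specific portions of input that reached tainted execution using character origins in the later sections. But before that, we explore other uses of taints.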

Preventing Privacy Leaks

Using taints, we can also ensure that secret information does not leak out. We can assign a special taint "SECRET" to strings whose information must not leak out:

secrets = tstr('<Plenty of secret keys>', taint='SECRET')

Accessing any substring of secrets will propagate the taint:

secrets[1:3].taint
'SECRET'

Consider the Heartbleed security leak from the chapter on Fuzzing, in which a server would accidentally reply with not only the user input sent to it, but also with secret memory. If the reply consists only of the user input, there is no taint associated with it:

user_input = "hello"
reply = user_input
isinstance(reply, tstr)
False

If, however, the reply contains any part of the secret, the reply will be tainted:

reply = user_input + secrets[0:5]
reply
'hello<Plen'
reply.taint
'SECRET'

The output function of our server would now ensure that the data sent back does not contain any secret information:

def send_back(s):
    assert not (isinstance(s, tstr) and s.taint == 'SECRET')
    ...
with ExpectError():
    send_back(reply)
Traceback (most recent call last):
  File "<ipython-input-103-e02d8e55c3ba>", line 2, in <module>
    send_back(reply)
  File "<ipython-input-102-a105f7cd1cab>", line 2, in send_back
    assert not (isinstance(s, tstr) and s.taint == 'SECRET')
AssertionError (expected)

Tracking Character Origins

Our tstr solution can help to identify information leaks, but it is by no means complete. If we actually take the heartbeat() implementation from the chapter on Fuzzing, we will see that any reply is marked as SECRET, even one that does not access secret memory at all:

from Fuzzer import heartbeat
reply = heartbeat('hello', 5, memory=secrets)
reply.taint
'SECRET'

Why is this? If we look into the implementation of heartbeat(), we will see that it first builds a long string memory from the (non-secret) reply and the (secret) memory, before returning the first characters from memory.

# Store reply in memory
memory = reply + memory[len(reply):]

At this point, the whole memory still is tainted as SECRET, including the non-secret part from reply.

We may be able to circumvent the issue by tagging the reply as PUBLIC – but then, this taint would be in conflict with the SECRET tag of memory. What happens if we compose a string from two differently tainted strings?

thilo = tstr("High", taint='HIGH') + tstr("Low", taint='LOW')

It turns out that in this case, the __add__() method takes precedence over the __radd__() method, which means that the right-hand "Low" string is treated as a regular (non-tainted) string.

thilo
'HighLow'
thilo.taint
'HIGH'

We could set up the __add__() and other methods with special handling for conflicting taints. However, the way this conflict should be resolved would be highly application-dependent:

  • If we use taints to indicate privacy levels, SECRET privacy should take precedence over PUBLIC privacy. Any combination of a SECRET-tainted string and a PUBLIC-tainted string thus should have a SECRET taint.

  • If we use taints to indicate origins of information, an UNTRUSTED origin should take precedence over a TRUSTED origin. Any combination of an UNTRUSTED-tainted string and a TRUSTED-tainted string thus should have an UNTRUSTED taint.

Of course, such conflict resolutions can be implemented. But even so, they will not help us in the heartbeat() example in differentiating secret from non-secret output data.
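A per-string conflict resolution could be sketched as a precedence order over taint labels, in which the higher-ranked label wins. The orders and the helper combine_taints() below are hypothetical and not part of tstr. Note that the result still carries a single taint for the whole string, which is exactly why this cannot separate secret from non-secret parts of a heartbeat() reply.

```python
# Hypothetical precedence orders: later entries take precedence
PRIVACY_ORDER = ['PUBLIC', 'SECRET']
TRUST_ORDER = ['TRUSTED', 'UNTRUSTED']


def combine_taints(a, b, order):
    """Return the taint of a string composed from strings tainted `a` and `b`.

    `None` stands for "untainted" and never takes precedence.
    """
    if a is None:
        return b
    if b is None:
        return a
    # the taint that comes later in `order` wins
    return a if order.index(a) >= order.index(b) else b


print(combine_taints('PUBLIC', 'SECRET', PRIVACY_ORDER))    # SECRET
print(combine_taints('UNTRUSTED', 'TRUSTED', TRUST_ORDER))  # UNTRUSTED
```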

Tracking Individual Characters

Fortunately, there is a better, more generic way to solve the above problems. The key to composing differently tainted strings is to assign taints not only to strings, but to every bit of information; in our case, to every character. If every character has a taint of its own, a new composition of characters simply inherits the taint of each character. To this end, we introduce a second bit of information named origin.
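The idea of per-character information can be illustrated without any string subclass: model a string as a pair of its text and a list of per-character origins, and let slicing and concatenation act on both in parallel. The helpers concat(), cut(), and toy_heartbeat() below are hypothetical; toy_heartbeat() only mimics the two steps of heartbeat() described above (store the reply at the start of memory, then return a prefix), and -1 stands for an unknown, non-secret origin.

```python
UNKNOWN = -1  # origin of characters that do not come from the secret


def concat(a, b):
    """Concatenate two (text, origins) pairs; origins travel with characters."""
    return a[0] + b[0], a[1] + b[1]


def cut(a, start, stop):
    """Slice a (text, origins) pair, keeping the origins of surviving characters."""
    return a[0][start:stop], a[1][start:stop]


def toy_heartbeat(reply, length, memory):
    # store the reply at the start of memory ...
    memory = concat(reply, cut(memory, len(reply[0]), len(memory[0])))
    # ... and return the first `length` characters
    return cut(memory, 0, length)


secret = ("<super-secret>", list(range(1000, 1014)))
reply = ("hello", [UNKNOWN] * 5)
print(toy_heartbeat(reply, 5, secret))  # ('hello', [-1, -1, -1, -1, -1])
print(toy_heartbeat(reply, 8, secret))  # ('hellor-s', [-1, -1, -1, -1, -1, 1005, 1006, 1007])
```

Unlike a single taint for the whole string, the origins show that a short reply contains nothing from the secret, while a longer one does.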

Distinguishing various untrusted sources may be accomplished by tainting each source separately (the distinct taints are called colors in dynamic taint research). You will see an instance of this technique in the chapter on Grammar Mining.

In this section, we track character-level origins. That is, given a fragment that resulted from a portion of the original tainted string, one will be able to tell which portion of the input string the fragment was taken from. In essence, each input character index from a tainted source gets its own color.

More complex tainting, such as bitmap taints, is possible, where a single character may result from multiple tainted character indexes (as with checksum operations on strings). We do not consider these in this chapter.

A Class for Tracking Character Origins

Let us introduce a class ostr which, like tstr, carries a taint for each string, and additionally an origin for each character that indicates its source. The origin is a number in a particular range (by default, starting with zero) that indicates the character's position within a specific source; consecutive characters get consecutive numbers.

class ostr(str):
    DEFAULT_ORIGIN = 0

    def __new__(cls, value, *args, **kw):
        return str.__new__(cls, value)

    def __init__(self, value, taint=None, origin=None, **kwargs):
        self.taint = taint

        if origin is None:
            origin = ostr.DEFAULT_ORIGIN
        if isinstance(origin, int):
            self.origin = list(range(origin, origin + len(self)))
        else:
            self.origin = origin
        assert len(self.origin) == len(self)
class ostr(ostr):
    def create(self, s):
        return ostr(s, taint=self.taint, origin=self.origin)
class ostr(ostr):
    UNKNOWN_ORIGIN = -1

    def __repr__(self):
        # handle escaped chars
        origin = [ostr.UNKNOWN_ORIGIN]
        for s, o in zip(str(self), self.origin):
            origin.extend([o] * (len(repr(s)) - 2))
        origin.append(ostr.UNKNOWN_ORIGIN)
        return ostr(str.__repr__(self), taint=self.taint, origin=origin)
class ostr(ostr):
    def __str__(self):
        return str.__str__(self)

By default, character origins start with 0:

thello = ostr('hello')
assert thello.origin == [0, 1, 2, 3, 4]

We can also specify the starting origin; here, the origins run from 6 to 10:

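tworld = ostr('world', origin=6)
assert tworld.origin == [6, 7, 8, 9, 10]

For an escaped character such as a tab, repr() assigns the character's origin to every character of its escape sequence, and an unknown origin to the surrounding quotes:

a = ostr("hello\tworld")
repr(a).origin
[-1, 0, 1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 10, -1]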

str() returns a plain str instance without origin or taint information:

assert type(str(thello)) == str

repr(), however, keeps the origin information for the original string:

repr(thello)
"'hello'"
repr(thello).origin
[-1, 0, 1, 2, 3, 4, -1]
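The origin list for repr() can be computed independently of ostr: every character contributes as many positions as its own escape sequence is long, i.e., len(repr(c)) - 2, and the two surrounding quotes get an unknown origin. The helper repr_origins() below is a hypothetical stand-alone version of this computation. It assumes that each character is escaped independently of its neighbours; this does not hold for strings containing both quote characters, for which repr() escapes every single quote, so the computed list is one entry short per single quote.

```python
UNKNOWN_ORIGIN = -1


def repr_origins(s, origin):
    """Origins for the characters of repr(s), given the origins of s."""
    result = [UNKNOWN_ORIGIN]  # opening quote
    for c, o in zip(s, origin):
        # repr(c) is c's escape sequence plus two quotes
        result.extend([o] * (len(repr(c)) - 2))
    result.append(UNKNOWN_ORIGIN)  # closing quote
    return result


print(repr("a\tb"))                     # 'a\tb'
print(repr_origins("a\tb", [0, 1, 2]))  # [-1, 0, 1, 1, 2, -1]
```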

Just as with taints, we can clear origins and check whether an origin is present:

class ostr(ostr):
    def clear_taint(self):
        self.taint = None
        return self

    def has_taint(self):
        return self.taint is not None
class ostr(ostr):
    def clear_origin(self):
        self.origin = [self.UNKNOWN_ORIGIN] * len(self)
        return self

    def has_origin(self):
        return any(origin != self.UNKNOWN_ORIGIN for origin in self.origin)
thello = ostr('Hello')
assert thello.has_origin()
thello.clear_origin()
assert not thello.has_origin()

In the remainder of this section, we re-implement various string methods such that they also keep track of origins. If this is too tedious for you, jump right to the next section which gives a number of usage examples.

Create

We need to create new substrings that are wrapped in ostr objects. However, we also want to allow our subclasses to create their own instances. Hence we again provide a create() method that produces a new ostr instance.

class ostr(ostr):
    def create(self, res, origin=None):
        return ostr(res, taint=self.taint, origin=origin)
thello = ostr('hello', taint='HIGH')
tworld = thello.create('world', origin=6)
tworld.origin
[6, 7, 8, 9, 10]
tworld.taint
'HIGH'
assert (thello.origin, tworld.origin) == (
    [0, 1, 2, 3, 4], [6, 7, 8, 9, 10])

Index

In Python, indexing is provided through __getitem__(). Indexing on positive integers is simple enough. However, it has two additional wrinkles. The first is that if the index is negative, characters are counted backwards from the end of the string, which lies just after the last character. That is, the last character has the index -1:

class ostr(ostr):
    def __getitem__(self, key):
        res = super().__getitem__(key)
        if isinstance(key, int):
            key = len(self) + key if key < 0 else key
            return self.create(res, [self.origin[key]])
        elif isinstance(key, slice):
            return self.create(res, self.origin[key])
        else:
            assert False
hello = ostr('hello', taint='HIGH')
assert (hello[0], hello[-1]) == ('h', 'o')
hello[0].taint
'HIGH'

The other wrinkle is that __getitem__() can accept a slice. We discuss this next.

Slices

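The Python slice operator [n:m] is also handled by __getitem__(), which then receives a slice object as key; our implementation above already covers this case. Iterating over a string, however, goes through __iter__(). To have iteration produce ostr characters that carry their origins, we define the __iter__() method, which returns a custom iterator.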
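Which hook Python uses for what can be observed with a small str subclass that logs its calls. The class Recorder below is purely illustrative and not part of ostr: slicing arrives at __getitem__() as a slice object, whereas iteration (as in list() or a for loop) goes through __iter__().

```python
class Recorder(str):
    """Illustrative str subclass that logs which hooks Python invokes."""
    calls = []

    def __getitem__(self, key):
        Recorder.calls.append(('__getitem__', type(key).__name__))
        return super().__getitem__(key)

    def __iter__(self):
        Recorder.calls.append(('__iter__', None))
        return super().__iter__()


r = Recorder("hello")
r[1:3]     # slicing
list(r)    # iteration
print(Recorder.calls)  # [('__getitem__', 'slice'), ('__iter__', None)]
```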

class ostr(ostr):
    def __iter__(self):
        return ostr_iterator(self)

The __iter__() method requires a supporting iterator object. The iterator is used to save the state of the current iteration, which it does by keeping a reference to the original ostr, and the current index of iteration _str_idx.

class ostr_iterator():
    def __init__(self, ostr):
        self._ostr = ostr
        self._str_idx = 0

    def __iter__(self):
        # iterators are expected to be iterable themselves
        return self

    def __next__(self):
        if self._str_idx == len(self._ostr):
            raise StopIteration
        # indexing calls ostr.__getitem__(), which returns an ostr
        c = self._ostr[self._str_idx]
        assert isinstance(c, ostr)
        self._str_idx += 1
        return c

Bringing all these together:

thw = ostr('hello world', taint='HIGH')
thw[0:5]
'hello'
assert thw[0:5].has_taint()
assert thw[0:5].has_origin()
thw[0:5].taint
'HIGH'
thw[0:5].origin
[0, 1, 2, 3, 4]

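Splits

As a first approximation, we wrap split(), rsplit(), and splitlines() such that each resulting part is an ostr carrying the taint of the original string. The parts do not yet get correct origins (the origins of each part restart at the default); we re-implement split() and rsplit() properly further below.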

def make_split_wrapper(fun):
    def proxy(self, *args, **kwargs):
        lst = fun(self, *args, **kwargs)
        return [self.create(elem) for elem in lst]
    return proxy
for name in ['split', 'rsplit', 'splitlines']:
    fun = getattr(str, name)
    setattr(ostr, name, make_split_wrapper(fun))
thello = ostr('hello world', taint='LOW')
thello == 'hello world'
True
thello.split()[0].taint
'LOW'

(Exercise for the reader: handle partitions, i.e., splitting a string by substrings)

Concatenation

If two origined strings are concatenated together, it may be desirable to transfer the origins from each to the corresponding portion of the resulting string. The concatenation of strings is accomplished by overriding __add__().

class ostr(ostr):
    def __add__(self, other):
        if isinstance(other, ostr):
            return self.create(str.__add__(self, other),
                               (self.origin + other.origin))
        else:
            return self.create(str.__add__(self, other),
                               (self.origin + [self.UNKNOWN_ORIGIN for i in other]))
thello = ostr("hello")
tworld = ostr("world", origin=6)
thw = thello + tworld
assert thw.origin == [0, 1, 2, 3, 4, 6, 7, 8, 9, 10]

What if an ostr is concatenated with a plain str?

space = "  "
th_w = thello + space + tworld
assert th_w.origin == [0, 1, 2, 3, 4,
                       ostr.UNKNOWN_ORIGIN, ostr.UNKNOWN_ORIGIN,
                       6, 7, 8, 9, 10]

One wrinkle here is that when adding an ostr and a str, the user may place the str first, in which case Python would call __add__() on the str instance, not on the ostr instance. However, Python provides a solution: if the right operand's type is a subclass of the left operand's type and provides its own __radd__(), that method is called first. Hence, by defining __radd__() in ostr, our method is called rather than str.__add__().
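This dispatch rule can be checked with a small stand-alone example. The class Marked below is hypothetical; it only reports which of its methods handled a '+' with a plain str on either side.

```python
class Marked(str):
    """Hypothetical str subclass that reports which hook handles '+'."""

    def __add__(self, other):
        return ('__add__', str.__add__(self, other))

    def __radd__(self, other):
        return ('__radd__', str.__add__(other, self))


plain, marked = "ab", Marked("cd")
print(marked + plain)  # ('__add__', 'cdab')
print(plain + marked)  # ('__radd__', 'abcd')
```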

class ostr(ostr):
    def __radd__(self, other):
        origin = other.origin if isinstance(other, ostr) else [
            self.UNKNOWN_ORIGIN for i in other]
        return self.create(str.__add__(other, self), (origin + self.origin))

We test it out:

shello = "hello"
tworld = ostr("world")
thw = shello + tworld
assert thw.origin == [ostr.UNKNOWN_ORIGIN] * len(shello) + [0, 1, 2, 3, 4]

These two methods, slicing and concatenation, are sufficient to implement other string methods that return a string and do not change the underlying characters (i.e., no case change). Before doing so, we look at a helper method.

Extract Origin String

Given a specific input index, the method x() extracts the characters of an ostr whose origin is that index. As a convenience, it also supports slices of input indexes.

class ostr(ostr):
    class TaintException(Exception):
        pass

    def x(self, i=0):
        if not self.origin:
            raise ostr.TaintException('Invalid request idx')
        if isinstance(i, int):
            return [self[p]
                    for p in [k for k, j in enumerate(self.origin) if j == i]]
        elif isinstance(i, slice):
            r = range(i.start or 0, i.stop or len(self), i.step or 1)
            return [self[p]
                    for p in [k for k, j in enumerate(self.origin) if j in r]]
thw = ostr('hello world', origin=100)
assert thw.x(101) == ['e']
assert thw.x(slice(101, 105)) == ['e', 'l', 'l', 'o']

Replace

The replace() method replaces a portion of the string with another.

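class ostr(ostr):
    def replace(self, a, b, n=-1):
        # Like str.replace(): a negative n replaces all occurrences
        assert a, "replacing the empty string is not supported"
        b_origin = b.origin if isinstance(
            b, ostr) else [self.UNKNOWN_ORIGIN] * len(b)
        mystr = str(self)
        res_str, res_origin = '', []
        start, i = 0, 0
        while n < 0 or i < n:
            # search only after the previous match, so that occurrences
            # of `a` inside the replacement `b` are not replaced again
            idx = mystr.find(a, start)
            if idx == -1:
                break
            res_str += mystr[start:idx] + str(b)
            res_origin += self.origin[start:idx] + b_origin
            start = idx + len(a)
            i += 1
        # keep the remainder after the last match, with its origins
        res_str += mystr[start:]
        res_origin += self.origin[start:]
        return self.create(res_str, res_origin)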
my_str = ostr("aa cde aa")
res = my_str.replace('aa', 'bb')
assert (res, res.origin) == ('bb cde bb',
                             [ostr.UNKNOWN_ORIGIN, ostr.UNKNOWN_ORIGIN,
                              2, 3, 4, 5, 6,
                              ostr.UNKNOWN_ORIGIN, ostr.UNKNOWN_ORIGIN])
my_str = ostr("aa cde aa")
res = my_str.replace('aa', ostr('bb', origin=100))
assert (
    res, res.origin) == (
        ('bb cde bb'), [
            100, 101, 2, 3, 4, 5, 6, 100, 101])

Split

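We essentially have to re-implement the split operations, such that each part keeps the origins of its position in the original string. Splitting on whitespace (the default, sep=None) is slightly different from splitting on an explicit separator, as runs of whitespace of any length separate the parts.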
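For an explicit separator, the approach boils down to offset arithmetic: since joining the parts with the separator reconstructs the original string, each part starts len(sep) characters after the previous one ends. The function split_with_origins() below is a hypothetical stand-alone sketch of this idea on plain strings and origin lists; it does not handle whitespace splitting, where the separators vary in length.

```python
def split_with_origins(s, origin, sep):
    """Split `s` on the non-empty separator `sep`, slicing `origin` in parallel."""
    pieces = []
    start = 0
    for part in s.split(sep):
        stop = start + len(part)
        pieces.append((part, origin[start:stop]))
        start = stop + len(sep)  # skip over the separator
    return pieces


print(split_with_origins("ab,cde", list(range(10, 16)), ","))
# [('ab', [10, 11]), ('cde', [13, 14, 15])]
```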

class ostr(ostr):
    def _split_helper(self, sep, splitted):
        # Explicit separator: joining the parts with `sep` reconstructs
        # the original string, so each part starts len(sep) characters
        # after the end of the previous part
        result_list = []
        first_idx = 0
        for s in splitted:
            last_idx = first_idx + len(s)
            result_list.append(self[first_idx:last_idx])
            first_idx = last_idx + len(sep)
        return result_list

    def _split_space(self, splitted):
        # Whitespace: runs of whitespace (including leading whitespace)
        # separate the parts. A part starts with a non-whitespace
        # character (only an unsplit rsplit() remainder may start with
        # whitespace, and it starts at index 0), so find() locates each
        # part right after the skipped whitespace
        result_list = []
        mystr = str(self)
        first_idx = 0
        for s in splitted:
            first_idx = mystr.find(s, first_idx)
            last_idx = first_idx + len(s)
            result_list.append(self[first_idx:last_idx])
            first_idx = last_idx
        return result_list

    def rsplit(self, sep=None, maxsplit=-1):
        splitted = super().rsplit(sep, maxsplit)
        if sep is None:
            return self._split_space(splitted)
        return self._split_helper(sep, splitted)

    def split(self, sep=None, maxsplit=-1):
        splitted = super().split(sep, maxsplit)
        if sep is None:
            return self._split_space(splitted)
        return self._split_helper(sep, splitted)
my_str = ostr('ab cdef ghij kl')
ab, cdef, ghij, kl = my_str.rsplit(sep=' ')
assert (ab.origin, cdef.origin, ghij.origin,
        kl.origin) == ([0, 1], [3, 4, 5, 6], [8, 9, 10, 11], [13, 14])

my_str = ostr('ab cdef ghij kl', origin=list(range(0, 15)))
ab, cdef, ghij, kl = my_str.rsplit(sep=' ')
assert (ab.origin, cdef.origin, kl.origin) == ([0, 1], [3, 4, 5, 6], [13, 14])
my_str = ostr('ab   cdef ghij    kl', origin=100, taint='HIGH')
ab, cdef, ghij, kl = my_str.rsplit()
assert (ab.origin, cdef.origin, ghij.origin,
        kl.origin) == ([100, 101], [105, 106, 107, 108], [110, 111, 112, 113],
                       [118, 119])

my_str = ostr('ab   cdef ghij    kl', origin=list(range(0, 20)), taint='HIGH')
ab, cdef, ghij, kl = my_str.split()
assert (ab.origin, cdef.origin, kl.origin) == ([0, 1], [5, 6, 7, 8], [18, 19])
assert ab.taint == 'HIGH'

Strip

class ostr(ostr):
    def strip(self, cl=None):
        return self.lstrip(cl).rstrip(cl)

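    def lstrip(self, cl=None):
        res = super().lstrip(cl)
        # lstrip() removes a prefix, so the result starts at this index
        # (self.find(res) would return 0 for an empty result)
        i = len(self) - len(res)
        return self[i:]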

    def rstrip(self, cl=None):
        res = super().rstrip(cl)
        return self[0:len(res)]
my_str1 = ostr("  abc  ")
v = my_str1.strip()
assert (v, v.origin) == ('abc', [2, 3, 4])
my_str1 = ostr("  abc  ")
v = my_str1.lstrip()
assert (v, v.origin) == ('abc  ', [2, 3, 4, 5, 6])
my_str1 = ostr("  abc  ")
v = my_str1.rstrip()
assert (v, v.origin) == ('  abc', [0, 1, 2, 3, 4])

Expand Tabs

class ostr(ostr):
    def expandtabs(self, n=8):
        res = super().expandtabs(n)
        all_parts = []
        col = 0  # current column, tracked as str.expandtabs() does
        for c, o in zip(str(self), self.origin):
            if c == '\t':
                # a tab expands up to the next multiple of n; the
                # resulting spaces all inherit the origin of the tab
                spaces = n - col % n if n > 0 else 0
                all_parts.extend([o] * spaces)
                col += spaces
            else:
                all_parts.append(o)
                col = 0 if c in '\n\r' else col + 1
        return self.create(res, all_parts)
my_str = str("ab\tcd")
my_ostr = ostr("ab\tcd")
v1 = my_str.expandtabs(4)
v2 = my_ostr.expandtabs(4)
assert str(v1) == str(v2)
assert (len(v1), repr(v2), v2.origin) == (6, "'ab  cd'", [0, 1, 2, 2, 3, 4])

Join

class ostr(ostr):
    def join(self, iterable):
        mystr = ''
        myorigin = []
        sep_origin = self.origin
        lst = list(iterable)
        for i, s in enumerate(lst):
            sorigin = s.origin if isinstance(s, ostr) else [
                self.UNKNOWN_ORIGIN] * len(s)
            myorigin.extend(sorigin)
            mystr += str(s)
            if i < len(lst) - 1:
                myorigin.extend(sep_origin)
                mystr += str(self)
        # join `lst`, not `iterable`, which may be an exhausted generator
        res = super().join(lst)
        assert len(res) == len(mystr)
        return self.create(res, myorigin)
my_str = ostr("ab cd", origin=100)
(v1, v2), v3 = my_str.split(), 'ef'
assert (v1.origin, v2.origin) == ([100, 101], [103, 104])
v4 = ostr('').join([v2, v3, v1])
assert (
    v4, v4.origin) == (
        'cdefab', [
            103, 104, ostr.UNKNOWN_ORIGIN, ostr.UNKNOWN_ORIGIN, 100, 101])
my_str = ostr("ab cd", origin=100)
(v1, v2), v3 = my_str.split(), 'ef'
assert (v1.origin, v2.origin) == ([100, 101], [103, 104])
v4 = ostr(',').join([v2, v3, v1])
assert (v4, v4.origin) == ('cd,ef,ab',
                           [103, 104, 0, ostr.UNKNOWN_ORIGIN, ostr.UNKNOWN_ORIGIN, 0, 100, 101])

Partitions

class ostr(ostr):
    def partition(self, sep):
        partA, sep, partB = super().partition(sep)
        return (self.create(partA, self.origin[0:len(partA)]),
                self.create(sep,
                            self.origin[len(partA):len(partA) + len(sep)]),
                self.create(partB, self.origin[len(partA) + len(sep):]))

    def rpartition(self, sep):
        partA, sep, partB = super().rpartition(sep)
        return (self.create(partA, self.origin[0:len(partA)]),
                self.create(sep,
                            self.origin[len(partA):len(partA) + len(sep)]),
                self.create(partB, self.origin[len(partA) + len(sep):]))

Justify

class ostr(ostr):
    def ljust(self, width, fillchar=' '):
        res = super().ljust(width, fillchar)
        # ljust() pads on the right
        padding = len(res) - len(self)
        if isinstance(fillchar, ostr):
            t = fillchar.origin[0]
        else:
            t = self.UNKNOWN_ORIGIN
        return self.create(res, self.origin + [t] * padding)
class ostr(ostr):
    def rjust(self, width, fillchar=' '):
        res = super().rjust(width, fillchar)
        # rjust() pads on the left
        padding = len(res) - len(self)
        if isinstance(fillchar, ostr):
            t = fillchar.origin[0]
        else:
            t = self.UNKNOWN_ORIGIN
        return self.create(res, [t] * padding + self.origin)

mod

class ostr(ostr):
    def __mod__(self, s):
        # nothing else implemented for the time being
        assert isinstance(s, str)
        s_origin = s.origin if isinstance(
            s, ostr) else [self.UNKNOWN_ORIGIN] * len(s)
        i = self.find('%s')
        assert i >= 0
        res = super().__mod__(s)
        r_origin = self.origin[:]
        r_origin[i:i + 2] = s_origin
        return self.create(res, origin=r_origin)
class ostr(ostr):
    def __rmod__(self, s):
        # nothing else implemented for the time being
        assert isinstance(s, str)
        r_origin = s.origin if isinstance(
            s, ostr) else [self.UNKNOWN_ORIGIN] * len(s)
        i = s.find('%s')
        assert i >= 0
        res = super().__rmod__(s)
        s_origin = self.origin[:]
        r_origin[i:i + 2] = s_origin
        return self.create(res, origin=r_origin)
a = ostr('hello %s world', origin=100)
a
'hello %s world'
(a % 'good').origin
[100, 101, 102, 103, 104, 105, -1, -1, -1, -1, 108, 109, 110, 111, 112, 113]
b = 'hello %s world'
c = ostr('bad', origin=10)
(b % c).origin
[-1, -1, -1, -1, -1, -1, 10, 11, 12, -1, -1, -1, -1, -1, -1]

String methods that do not change origin

class ostr(ostr):
    def swapcase(self):
        return self.create(str(self).swapcase(), self.origin)

    def upper(self):
        return self.create(str(self).upper(), self.origin)

    def lower(self):
        return self.create(str(self).lower(), self.origin)

    def capitalize(self):
        return self.create(str(self).capitalize(), self.origin)

    def title(self):
        return self.create(str(self).title(), self.origin)
a = ostr('aa', origin=100).upper()
a, a.origin
('AA', [100, 101])

General wrappers

These are not strictly needed for operation, but can be useful for tracing.

def make_str_wrapper(fun):
    def proxy(*args, **kwargs):
        res = fun(*args, **kwargs)
        return res
    return proxy
import inspect
import types
ostr_members = [name for name, fn in inspect.getmembers(ostr, callable)
                if isinstance(fn, types.FunctionType) and fn.__qualname__.startswith('ostr')]

for name, fn in inspect.getmembers(str, callable):
    if name not in set(['__class__', '__new__', '__str__', '__init__',
                        '__repr__', '__getattribute__']) | set(ostr_members):
        setattr(ostr, name, make_str_wrapper(fn))

Methods yet to be translated

These methods generate strings from other strings. However, we do not yet have the right implementations for any of them. Hence, we mark them as unsupported: calling any of them raises a TaintException until we can provide the right translations.

def make_str_abort_wrapper(fun):
    def proxy(*args, **kwargs):
        raise ostr.TaintException(
            '%s Not implemented in `ostr`' %
            fun.__name__)
    return proxy
for name, fn in inspect.getmembers(str, callable):
    # Omitted 'splitlines' as this is needed for formatting output in
    # IPython/Jupyter
    if name in ['__format__', 'format_map', 'format',
                '__mul__', '__rmul__', 'center', 'zfill', 'decode', 'encode']:
        setattr(ostr, name, make_str_abort_wrapper(fn))

Checking Origins

With all this implemented, we now have full-fledged ostr strings where we can easily check the origin of each and every character.

To check whether a string originates from another string, we can convert the origin to a set and resort to standard set operations:

s = ostr("hello", origin=100)
s[1]
'e'
s[1].origin
[101]
set(s[1].origin) <= set(s.origin)
True
t = ostr("world", origin=200)
set(s.origin) <= set(t.origin)
False
u = s + t + "!"
u.origin
[100, 101, 102, 103, 104, 200, 201, 202, 203, 204, -1]
ostr.UNKNOWN_ORIGIN in u.origin
True
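Beyond yes/no checks with sets, it can be useful to know which characters of an output leaked. Assuming, as in the next section, that secret characters carry origins in a known contiguous range, a hypothetical helper leaked_positions() could list the affected indexes:

```python
UNKNOWN_ORIGIN = -1


def leaked_positions(origin, secret_start, secret_length):
    """Indexes of characters whose origin lies in the secret's origin range."""
    secret = range(secret_start, secret_start + secret_length)
    return [i for i, o in enumerate(origin) if o in secret]


reply_origin = [UNKNOWN_ORIGIN] * 5 + [1005, 1006, 1007]
print(leaked_positions(reply_origin, 1000, 32))  # [5, 6, 7]
```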

Privacy Leaks Revisited

Let us apply ostr to see whether we can come up with a satisfactory solution for checking the heartbeat() function against information leakage.

SECRET_ORIGIN = 1000

We define a "secret" that must not leak out:

secret = ostr('<again, some super-secret input>', origin=SECRET_ORIGIN)

Each and every character in secret has an origin starting with SECRET_ORIGIN:

print(secret.origin)
[1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031]

If we now invoke heartbeat() with a given string, the origins of the reply should all be UNKNOWN_ORIGIN (as they come from the plain-str input), and none of the characters should have a SECRET_ORIGIN.

s = heartbeat('hello', 5, memory=secret)
s
'hello'
print(s.origin)
[-1, -1, -1, -1, -1]

We can verify that the secret did not leak out by formulating appropriate assertions:

assert s.origin == [ostr.UNKNOWN_ORIGIN] * len(s)
assert all(origin == ostr.UNKNOWN_ORIGIN for origin in s.origin)
assert not any(origin >= SECRET_ORIGIN for origin in s.origin)

All assertions pass, again confirming that no secret leaked out.

Let us now go and exploit heartbeat() to reveal its secrets. As heartbeat() is unchanged, it is as vulnerable as it was:

s = heartbeat('hello', 32, memory=secret)
s
'hellon, some super-secret input>'

Now, however, the reply does contain secret information:

print(s.origin)
[-1, -1, -1, -1, -1, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031]
with ExpectError():
    assert s.origin == [ostr.UNKNOWN_ORIGIN] * len(s)
Traceback (most recent call last):
  File "<ipython-input-203-6cac9e5bbde7>", line 2, in <module>
    assert s.origin == [ostr.UNKNOWN_ORIGIN] * len(s)
AssertionError (expected)
with ExpectError():
    assert all(origin == ostr.UNKNOWN_ORIGIN for origin in s.origin)
Traceback (most recent call last):
  File "<ipython-input-204-860ea80b8867>", line 2, in <module>
    assert all(origin == ostr.UNKNOWN_ORIGIN for origin in s.origin)
AssertionError (expected)
with ExpectError():
    assert not any(origin >= SECRET_ORIGIN for origin in s.origin)
Traceback (most recent call last):
  File "<ipython-input-205-9630f3080c59>", line 2, in <module>
    assert not any(origin >= SECRET_ORIGIN for origin in s.origin)
AssertionError (expected)

We can now integrate these assertions into the heartbeat() function, causing it to fail before leaking information. Additionally (or alternatively), we can rewrite our output functions not to give out any secret information. We leave these two exercises to the reader.

Taint-Directed Fuzzing

The previous Taint Aware Fuzzing was a bit unsatisfactory in that we could not focus on the specific parts of the grammar that led to dangerous operations. We fix that with taint-directed fuzzing using TrackingDB.

The idea here is to track the origins of each character that reaches eval. Then, track it back to the grammar nodes that generated it, and increase the probability of using those nodes again.
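The code in this section records how often each grammar rule contributed to a string reaching eval (in the 'use' entries of ctp_grammar). One way such counts could raise the probability of those rules is to choose alternatives with a weight of 1 + use, so that unused alternatives remain possible. The function choose_expansion() below is a hypothetical sketch of this idea; it is not part of TaintedGrammarFuzzer.

```python
import random


def choose_expansion(alternatives, rng=random):
    """Pick an expansion with probability proportional to 1 + its use count.

    `alternatives` follows the shape of an entry in ctp_grammar:
    a list of (expansion, {'use': count}) pairs.
    """
    weights = [1 + info['use'] for _, info in alternatives]
    expansion, _ = rng.choices(alternatives, weights=weights, k=1)[0]
    return expansion


alts = [(['<bexpr>'], {'use': 8}), (['<aexpr>', '+', '<aexpr>'], {'use': 0})]
rng = random.Random(0)
picks = [choose_expansion(alts, rng) for _ in range(1000)]
# '<bexpr>' has weight 9 and the unused alternative weight 1,
# so roughly nine out of ten picks should be ['<bexpr>']
print(sum(p == ['<bexpr>'] for p in picks))
```

The additive 1 acts as smoothing: without it, an alternative that never reached eval so far could never be chosen again.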

TrackingDB

The TrackingDB is similar to TaintedDB. The difference is that if execution reaches my_eval() with a statement that carries origin information, we simply raise a Tainted exception for that statement.

class TrackingDB(TaintedDB):
    def my_eval(self, statement, g, l):
        if statement.origin:
            raise Tainted(statement)
        try:
            return eval(statement, g, l)
        except:
            raise SQLException('Invalid SQL (%s)' % repr(statement))

Next, we need a specially crafted fuzzer that preserves the taints.

TaintedGrammarFuzzer

We define a TaintedGrammarFuzzer class that ensures that the taints propagate to the derivation tree. This is similar to the GrammarFuzzer from the chapter on grammar fuzzers except that the origins and taints are preserved.

import random
from Grammars import START_SYMBOL
from GrammarFuzzer import GrammarFuzzer
from Parser import canonical
class TaintedGrammarFuzzer(GrammarFuzzer):
    def __init__(self,
                 grammar,
                 start_symbol=START_SYMBOL,
                 expansion_switch=1,
                 log=False):
        self.tainted_start_symbol = ostr(
            start_symbol, origin=[1] * len(start_symbol))
        self.expansion_switch = expansion_switch
        self.log = log
        self.grammar = grammar
        self.c_grammar = canonical(grammar)
        self.init_tainted_grammar()

    def expansion_cost(self, expansion, seen=set()):
        symbols = [e for e in expansion if e in self.c_grammar]
        if len(symbols) == 0:
            return 1

        if any(s in seen for s in symbols):
            return float('inf')

        return sum(self.symbol_cost(s, seen) for s in symbols) + 1

    def fuzz_tree(self):
        tree = (self.tainted_start_symbol, [])
        nt_leaves = [tree]
        expansion_trials = 0
        while nt_leaves:
            idx = random.randint(0, len(nt_leaves) - 1)
            key, children = nt_leaves[idx]
            expansions = self.ct_grammar[key]
            if expansion_trials < self.expansion_switch:
                expansion = random.choice(expansions)
            else:
                costs = [self.expansion_cost(e) for e in expansions]
                m = min(costs)
                all_min = [i for i, c in enumerate(costs) if c == m]
                expansion = expansions[random.choice(all_min)]

            new_leaves = [(token, []) for token in expansion]
            new_nt_leaves = [e for e in new_leaves if e[0] in self.ct_grammar]
            children[:] = new_leaves
            nt_leaves[idx:idx + 1] = new_nt_leaves
            if self.log:
                print("%-40s" % (key + " -> " + str(expansion)))
            expansion_trials += 1
        return tree

    def fuzz(self):
        self.derivation_tree = self.fuzz_tree()
        return self.tree_to_string(self.derivation_tree)

We use a specially prepared tainted grammar for fuzzing. We mark each individual definition, each individual rule, and each individual token with a separate origin (we chose a token boundary of 10 here, after inspecting the grammar). This allows us to track exactly which parts of the grammar were involved in the operations we are interested in.

class TaintedGrammarFuzzer(TaintedGrammarFuzzer):
    def init_tainted_grammar(self):
        key_increment, alt_increment, token_increment = 1000, 100, 10
        key_origin = key_increment
        self.ct_grammar = {}
        for key, val in self.c_grammar.items():
            key_origin += key_increment
            os = []
            for v in val:
                ts = []
                key_origin += alt_increment
                for t in v:
                    nt = ostr(t, origin=key_origin)
                    key_origin += token_increment
                    ts.append(nt)
                os.append(ts)
            self.ct_grammar[key] = os

        # a use tracking grammar
        self.ctp_grammar = {}
        for key, val in self.ct_grammar.items():
            self.ctp_grammar[key] = [(v, dict(use=0)) for v in val]
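Since the three increments accumulate while iterating over the grammar, the resulting origins are running numbers rather than a positional key/alternative/token code. The function token_origins() below is a hypothetical re-statement of the numbering in init_tainted_grammar() on plain integers; for a tiny grammar, it yields the starting origin of each token. It also shows why the token boundary matters: a token longer than token_increment characters would share origins with the token after it.

```python
def token_origins(grammar, key_increment=1000, alt_increment=100,
                  token_increment=10):
    """Starting origin of each token, numbered as in init_tainted_grammar()."""
    key_origin = key_increment
    result = {}
    for key, alternatives in grammar.items():
        key_origin += key_increment
        result[key] = []
        for alternative in alternatives:
            key_origin += alt_increment
            starts = []
            for token in alternative:
                # the token's characters get origins key_origin, key_origin + 1, ...
                starts.append(key_origin)
                key_origin += token_increment
            result[key].append(starts)
    return result


tiny = {'<start>': [['<digit>']], '<digit>': [['0'], ['1']]}
print(token_origins(tiny))
# {'<start>': [[2100]], '<digit>': [[3210], [3320]]}
```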

As before, we initialize the TrackingDB:

trdb = TrackingDB(db.db)

Finally, we need to ensure that the taints are preserved when the tree is converted back to a string. For this, we override tree_to_string() such that it joins the strings of the children using an ostr:

class TaintedGrammarFuzzer(TaintedGrammarFuzzer):
    def tree_to_string(self, tree):
        symbol, children, *_ = tree
        e = ostr('')
        if children:
            return e.join([self.tree_to_string(c) for c in children])
        else:
            return e if symbol in self.c_grammar else symbol

We define update_grammar(), which accepts the set of origins that reached the dangerous operation, along with the derivation tree of the string used for fuzzing, and updates the use counts of the enhanced grammar accordingly.

class TaintedGrammarFuzzer(TaintedGrammarFuzzer):
    def update_grammar(self, origin, dtree):
        def update_tree(dtree, origin):
            key, children = dtree
            if children:
                updated_children = [update_tree(c, origin) for c in children]
                corigin = set.union(
                    *[o for (key, children, o) in updated_children])
                corigin = corigin.union(set(key.origin))
                return (key, children, corigin)
            else:
                my_origin = set(key.origin).intersection(origin)
                return (key, [], my_origin)

        key, children, oset = update_tree(dtree, set(origin))
        for key, alts in self.ctp_grammar.items():
            for alt, o in alts:
                alt_origins = set([i for token in alt for i in token.origin])
                if alt_origins.intersection(oset):
                    o['use'] += 1

With these, we are now ready to fuzz.

import traceback

def tree_type(tree):
    key, children = tree
    return (type(key), key, [tree_type(c) for c in children])
tgf = TaintedGrammarFuzzer(INVENTORY_GRAMMAR_F)
x = None
for _ in range(10):
    qtree = tgf.fuzz_tree()
    query = tgf.tree_to_string(qtree)
    assert isinstance(query, ostr)
    try:
        print(repr(query))
        res = trdb.sql(query)
        print(repr(res))
    except SQLException as e:
        print(e)
    except Tainted as e:
        print(e)
        origin = e.args[0].origin
        tgf.update_grammar(origin, qtree)
    except:
        traceback.print_exc()
        break
    print()
'select (g!=(9)!=((:)==2==9)!=J)==-7 from inventory'
Tainted[((g!=(9)!=((:)==2==9)!=J)==-7)]

'delete from inventory where ((c)==T)!=5==(8!=Y)!=-5'
Tainted[((c)==T)!=5==(8!=Y)!=-5]

'select (((w==(((X!=------8)))))) from inventory'
Tainted[((((w==(((X!=------8)))))))]

'delete from inventory where ((.==(-3)!=(((-3))))!=(S==(((n))==Y))!=--2!=N==-----0==--0)!=(((((R))))==((v)))!=((((((------2==Q==-8!=(q)!=(((.!=2))==J)!=(1)!=(((-4!=--5==J!=(((A==.)))))!=(((((0==(P!=((R))!=(((j)))!=7))))==O==K))==(q))==--1==((H)==(t)==s!=-6==((y))==R)!=((H))!=W==--4==(P==(u)==-0)!=O==((-5==-------2!=4!=U))!=-1==((((((R!=-6))))))!=1!=Z)))==(((I)!=((S))!=(-4==s)==(7!=(A))==(s)==p==((_)!=(C))==((w)))))))'
Tainted[((.==(-3)!=(((-3))))!=(S==(((n))==Y))!=--2!=N==-----0==--0)!=(((((R))))==((v)))!=((((((------2==Q==-8!=(q)!=(((.!=2))==J)!=(1)!=(((-4!=--5==J!=(((A==.)))))!=(((((0==(P!=((R))!=(((j)))!=7))))==O==K))==(q))==--1==((H)==(t)==s!=-6==((y))==R)!=((H))!=W==--4==(P==(u)==-0)!=O==((-5==-------2!=4!=U))!=-1==((((((R!=-6))))))!=1!=Z)))==(((I)!=((S))!=(-4==s)==(7!=(A))==(s)==p==((_)!=(C))==((w)))))))]

'delete from inventory where ((2)==T!=-1)==N==(P)==((((((6==a)))))!=8)==(3)!=((---7))'
Tainted[((2)==T!=-1)==N==(P)==((((((6==a)))))!=8)==(3)!=((---7))]

'delete from inventory where o!=2==---5==3!=t'
Tainted[o!=2==---5==3!=t]

'select (2) from inventory'
Tainted[((2))]

'select _ from inventory'
Tainted[(_)]

'select L!=(((1!=(Z)==C)!=C))==(((-0==-5==Q!=((--2!=(-0)==((0))==M)==(A))!=(X)!=e==(K==((b)))!=b==9==((((l)!=-7!=4)!=s==G))!=6==((((5==(((v==(((((((a!=d))==0!=4!=(4)==--1==(h)==-8!=(9)==-4)))))!=I!=-4))==v!=(Y==b)))==(a))!=((7)))))))==((4)) from inventory'
Tainted[(L!=(((1!=(Z)==C)!=C))==(((-0==-5==Q!=((--2!=(-0)==((0))==M)==(A))!=(X)!=e==(K==((b)))!=b==9==((((l)!=-7!=4)!=s==G))!=6==((((5==(((v==(((((((a!=d))==0!=4!=(4)==--1==(h)==-8!=(9)==-4)))))!=I!=-4))==v!=(Y==b)))==(a))!=((7)))))))==((4)))]

'delete from inventory where _==(7==(9)!=(---5)==1)==-8'
Tainted[_==(7==(9)!=(---5)==1)==-8]

We can now inspect our enhanced grammar to see how many times each rule was used.

tgf.ctp_grammar
{'<start>': [(['<query>'], {'use': 10})],
 '<expr>': [(['<bexpr>'], {'use': 8}),
  (['<aexpr>'], {'use': 8}),
  (['(', '<expr>', ')'], {'use': 8}),
  (['<term>'], {'use': 10})],
 '<bexpr>': [(['<aexpr>', '<lt>', '<aexpr>'], {'use': 0}),
  (['<aexpr>', '<gt>', '<aexpr>'], {'use': 0}),
  (['<expr>', '==', '<expr>'], {'use': 8}),
  (['<expr>', '!=', '<expr>'], {'use': 8})],
 '<aexpr>': [(['<aexpr>', '+', '<aexpr>'], {'use': 0}),
  (['<aexpr>', '-', '<aexpr>'], {'use': 0}),
  (['<aexpr>', '*', '<aexpr>'], {'use': 0}),
  (['<aexpr>', '/', '<aexpr>'], {'use': 0}),
  (['<word>', '(', '<exprs>', ')'], {'use': 0}),
  (['<expr>'], {'use': 8})],
 '<exprs>': [(['<expr>', ',', '<exprs>'], {'use': 0}),
  (['<expr>'], {'use': 5})],
 '<lt>': [(['<'], {'use': 0})],
 '<gt>': [(['>'], {'use': 0})],
 '<term>': [(['<number>'], {'use': 9}), (['<word>'], {'use': 9})],
 '<number>': [(['<integer>', '.', '<integer>'], {'use': 0}),
  (['<integer>'], {'use': 9}),
  (['-', '<number>'], {'use': 8})],
 '<integer>': [(['<digit>', '<integer>'], {'use': 0}),
  (['<digit>'], {'use': 9})],
 '<word>': [(['<word>', '<letter>'], {'use': 0}),
  (['<word>', '<digit>'], {'use': 0}),
  (['<letter>'], {'use': 9})],
 '<digit>': [(['0'], {'use': 2}),
  (['1'], {'use': 4}),
  (['2'], {'use': 6}),
  (['3'], {'use': 3}),
  (['4'], {'use': 2}),
  (['5'], {'use': 5}),
  (['6'], {'use': 3}),
  (['7'], {'use': 5}),
  (['8'], {'use': 6}),
  (['9'], {'use': 3})],
 '<letter>': [(['a'], {'use': 2}),
  (['b'], {'use': 1}),
  (['c'], {'use': 1}),
  (['d'], {'use': 1}),
  (['e'], {'use': 1}),
  (['f'], {'use': 0}),
  (['g'], {'use': 1}),
  (['h'], {'use': 1}),
  (['i'], {'use': 0}),
  (['j'], {'use': 1}),
  (['k'], {'use': 0}),
  (['l'], {'use': 1}),
  (['m'], {'use': 0}),
  (['n'], {'use': 1}),
  (['o'], {'use': 1}),
  (['p'], {'use': 1}),
  (['q'], {'use': 1}),
  (['r'], {'use': 0}),
  (['s'], {'use': 2}),
  (['t'], {'use': 2}),
  (['u'], {'use': 1}),
  (['v'], {'use': 2}),
  (['w'], {'use': 2}),
  (['x'], {'use': 0}),
  (['y'], {'use': 1}),
  (['z'], {'use': 0}),
  (['A'], {'use': 2}),
  (['B'], {'use': 0}),
  (['C'], {'use': 2}),
  (['D'], {'use': 0}),
  (['E'], {'use': 0}),
  (['F'], {'use': 0}),
  (['G'], {'use': 1}),
  (['H'], {'use': 1}),
  (['I'], {'use': 2}),
  (['J'], {'use': 2}),
  (['K'], {'use': 2}),
  (['L'], {'use': 1}),
  (['M'], {'use': 1}),
  (['N'], {'use': 2}),
  (['O'], {'use': 1}),
  (['P'], {'use': 2}),
  (['Q'], {'use': 2}),
  (['R'], {'use': 1}),
  (['S'], {'use': 1}),
  (['T'], {'use': 2}),
  (['U'], {'use': 1}),
  (['V'], {'use': 0}),
  (['W'], {'use': 1}),
  (['X'], {'use': 2}),
  (['Y'], {'use': 3}),
  (['Z'], {'use': 2}),
  (['_'], {'use': 3}),
  ([':'], {'use': 1}),
  (['.'], {'use': 1})],
 '<query>': [(['select ', '<exprs>', ' from ', '<table>'], {'use': 5}),
  (['select ', '<exprs>', ' from ', '<table>', ' where ', '<bexpr>'],
   {'use': 0}),
  (['insert into ',
    '<table>',
    ' (',
    '<names>',
    ') values (',
    '<literals>',
    ')'],
   {'use': 0}),
  (['update ', '<table>', ' set ', '<assignments>', ' where ', '<bexpr>'],
   {'use': 0}),
  (['delete from ', '<table>', ' where ', '<bexpr>'], {'use': 5})],
 '<table>': [(['inventory'], {'use': 0})],
 '<names>': [(['<column>', ',', '<names>'], {'use': 0}),
  (['<column>'], {'use': 0})],
 '<column>': [(['<word>'], {'use': 0})],
 '<literals>': [(['<literal>'], {'use': 0}),
  (['<literal>', ',', '<literals>'], {'use': 0})],
 '<literal>': [(['<number>'], {'use': 0}),
  (["'", '<chars>', "'"], {'use': 0})],
 '<assignments>': [(['<kvp>', ',', '<assignments>'], {'use': 0}),
  (['<kvp>'], {'use': 0})],
 '<kvp>': [(['<column>', '=', '<value>'], {'use': 0})],
 '<value>': [(['<word>'], {'use': 0})],
 '<chars>': [(['<char>'], {'use': 0}), (['<char>', '<chars>'], {'use': 0})],
 '<char>': [(['0'], {'use': 0}),
  (['1'], {'use': 0}),
  (['2'], {'use': 0}),
  (['3'], {'use': 0}),
  (['4'], {'use': 0}),
  (['5'], {'use': 0}),
  (['6'], {'use': 0}),
  (['7'], {'use': 0}),
  (['8'], {'use': 0}),
  (['9'], {'use': 0}),
  (['a'], {'use': 0}),
  (['b'], {'use': 0}),
  (['c'], {'use': 0}),
  (['d'], {'use': 0}),
  (['e'], {'use': 0}),
  (['f'], {'use': 0}),
  (['g'], {'use': 0}),
  (['h'], {'use': 0}),
  (['i'], {'use': 0}),
  (['j'], {'use': 0}),
  (['k'], {'use': 0}),
  (['l'], {'use': 0}),
  (['m'], {'use': 0}),
  (['n'], {'use': 0}),
  (['o'], {'use': 0}),
  (['p'], {'use': 0}),
  (['q'], {'use': 0}),
  (['r'], {'use': 0}),
  (['s'], {'use': 0}),
  (['t'], {'use': 0}),
  (['u'], {'use': 0}),
  (['v'], {'use': 0}),
  (['w'], {'use': 0}),
  (['x'], {'use': 0}),
  (['y'], {'use': 0}),
  (['z'], {'use': 0}),
  (['A'], {'use': 0}),
  (['B'], {'use': 0}),
  (['C'], {'use': 0}),
  (['D'], {'use': 0}),
  (['E'], {'use': 0}),
  (['F'], {'use': 0}),
  (['G'], {'use': 0}),
  (['H'], {'use': 0}),
  (['I'], {'use': 0}),
  (['J'], {'use': 0}),
  (['K'], {'use': 0}),
  (['L'], {'use': 0}),
  (['M'], {'use': 0}),
  (['N'], {'use': 0}),
  (['O'], {'use': 0}),
  (['P'], {'use': 0}),
  (['Q'], {'use': 0}),
  (['R'], {'use': 0}),
  (['S'], {'use': 0}),
  (['T'], {'use': 0}),
  (['U'], {'use': 0}),
  (['V'], {'use': 0}),
  (['W'], {'use': 0}),
  (['X'], {'use': 0}),
  (['Y'], {'use': 0}),
  (['Z'], {'use': 0}),
  (['!'], {'use': 0}),
  (['#'], {'use': 0}),
  (['$'], {'use': 0}),
  (['%'], {'use': 0}),
  (['&'], {'use': 0}),
  (['('], {'use': 0}),
  ([')'], {'use': 0}),
  (['*'], {'use': 0}),
  (['+'], {'use': 0}),
  ([','], {'use': 0}),
  (['-'], {'use': 0}),
  (['.'], {'use': 0}),
  (['/'], {'use': 0}),
  ([':'], {'use': 0}),
  ([';'], {'use': 0}),
  (['='], {'use': 0}),
  (['?'], {'use': 0}),
  (['@'], {'use': 0}),
  (['['], {'use': 0}),
  (['\\'], {'use': 0}),
  ([']'], {'use': 0}),
  (['^'], {'use': 0}),
  (['_'], {'use': 0}),
  (['`'], {'use': 0}),
  (['{'], {'use': 0}),
  (['|'], {'use': 0}),
  (['}'], {'use': 0}),
  (['~'], {'use': 0}),
  ([' '], {'use': 0}),
  (['<lt>'], {'use': 0}),
  (['<gt>'], {'use': 0})]}
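Such a table is easier to read when the expansions are ranked by their use counts. The following is a minimal sketch, not part of the chapter's code: rank_expansions() is a hypothetical helper, and grammar_excerpt copies only a few entries from the output above, in the same (expansion, {'use': n}) format.

```python
# A few entries copied from the use-count grammar above (excerpt only)
grammar_excerpt = {
    '<query>': [
        (['select ', '<exprs>', ' from ', '<table>'], {'use': 5}),
        (['update ', '<table>', ' set ', '<assignments>', ' where ', '<bexpr>'], {'use': 0}),
        (['delete from ', '<table>', ' where ', '<bexpr>'], {'use': 5}),
    ],
    '<bexpr>': [
        (['<aexpr>', '<lt>', '<aexpr>'], {'use': 0}),
        (['<expr>', '==', '<expr>'], {'use': 8}),
    ],
}

def rank_expansions(grammar):
    """Hypothetical helper: (nonterminal, expansion, use) triples, most used first."""
    ranked = [(key, ''.join(alt), o['use'])
              for key, alts in grammar.items()
              for alt, o in alts]
    # sorted() is stable, so ties keep their order from the grammar
    return sorted(ranked, key=lambda entry: entry[2], reverse=True)

for key, expansion, use in rank_expansions(grammar_excerpt):
    print(f"{use:3d}  {key} ::= {expansion!r}")
```

Expansions with a count of zero never contributed to a tainted query in this run, which makes them candidates for being de-emphasized.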

From here, the idea is to focus on the rules that reached dangerous operations more often, and to increase the probability of choosing expansions of that kind.
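One simple way to act on these counts, sketched here as an assumption rather than as the chapter's implementation, is to choose each expansion with a probability proportional to its use count plus one. The hypothetical helpers expansion_probabilities() and choose_expansion() operate on the (expansion, {'use': n}) pairs shown above; adding one keeps never-used expansions reachable.

```python
import random

def expansion_probabilities(alts):
    """Probability of each (expansion, {'use': n}) pair, proportional to n + 1."""
    # Add-one smoothing: expansions with use 0 keep a small, non-zero chance
    weights = [o['use'] + 1 for _, o in alts]
    total = sum(weights)
    return [w / total for w in weights]

def choose_expansion(alts, rng=random):
    """Randomly pick one expansion, favoring those that reached tainted sinks."""
    weights = [o['use'] + 1 for _, o in alts]
    alt, _ = rng.choices(alts, weights=weights, k=1)[0]
    return alt

# The <query> alternatives and use counts from the output above (excerpt)
query_alts = [
    (['select ', '<exprs>', ' from ', '<table>'], {'use': 5}),
    (['update ', '<table>', ' set ', '<assignments>', ' where ', '<bexpr>'], {'use': 0}),
    (['delete from ', '<table>', ' where ', '<bexpr>'], {'use': 5}),
]
print(expansion_probabilities(query_alts))
print(choose_expansion(query_alts, random.Random(42)))
```

With these weights (6, 1, and 6), select and delete queries would each be chosen with probability 6/13, and update queries with probability 1/13.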

The Limits of Taint Tracking

While our framework can detect information leakage, it is by no means perfect. There are several ways in which taints can get lost and information thus may still leak out.

Conversions

We only track taints and origins through strings and characters. If we convert these to numbers (or other data), the information is lost.

As an example, consider this function, converting individual characters to numbers and back:

def strip_all_info(s):
    t = ""
    for c in s:
        t += chr(ord(c))
    return t
thello = ostr("Secret")
thello
'Secret'
thello.origin
[0, 1, 2, 3, 4, 5]

The taints and origins will not propagate through the number conversion:

thello_stripped = strip_all_info(thello)
thello_stripped
'Secret'
with ExpectError():
    thello_stripped.origin
Traceback (most recent call last):
  File "<ipython-input-223-56d5157cf575>", line 2, in <module>
    thello_stripped.origin
AttributeError: 'str' object has no attribute 'origin' (expected)

This issue could be addressed by extending numbers with taints and origins, just as we did for strings. At some point, however, this will still break down, because as soon as an internal C function in the Python library is reached, the taint will not propagate into and across the C function. (Unless one starts implementing dynamic taints for these, that is.)
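To illustrate both halves of this argument, here is a minimal sketch of a tainted integer. The class TaintedInt is hypothetical (it is not the tint class from the exercises below) and propagates taints through addition only. The taint survives n + 1, but as soon as the value passes through chr(), a built-in implemented in C, the result is a plain str and the taint is gone.

```python
class TaintedInt(int):
    """Hypothetical minimal tainted integer; propagates taints through + only."""
    def __new__(cls, value, taint=None):
        obj = super().__new__(cls, value)
        obj.taint = taint
        return obj

    def __add__(self, other):
        # Sketch: integer operands only; the result inherits our taint
        return TaintedInt(int(self) + int(other), taint=self.taint)

    # Also handle 1 + n, where the plain int is on the left
    __radd__ = __add__

n = TaintedInt(ord('S'), taint='SECRET')
m = n + 1
print(type(m).__name__, m.taint)                 # TaintedInt SECRET

# chr() is a built-in implemented in C: it returns a plain str without a taint
c = chr(n)
print(c, type(c).__name__, hasattr(c, 'taint'))  # S str False
```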

Internal C libraries

As we mentioned before, calls to internal C libraries do not propagate taints. For example, while the following preserves the origins,

hello = ostr('hello', origin=100)
world = ostr('world', origin=200)
(hello + ' ' + world).origin
[100, 101, 102, 103, 104, -1, 200, 201, 202, 203, 204]

an equivalent call to ''.join() returns a plain str without origins, so accessing .origin fails:

with ExpectError():
    ''.join([hello, ' ', world]).origin
Traceback (most recent call last):
  File "<ipython-input-225-ad148b54cc0b>", line 2, in <module>
    ''.join([hello, ' ', world]).origin
AttributeError: 'str' object has no attribute 'origin' (expected)
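One workaround is to avoid the C implementation and build the result with +, so that the overloaded addition operators run at every step. To stay self-contained, the sketch below uses OStr, a hypothetical and much reduced stand-in for ostr that only tracks origins through concatenation; tracked_join() is likewise a hypothetical helper, not part of the chapter's code.

```python
class OStr(str):
    """Hypothetical minimal origin-tracking string (a reduced stand-in for ostr)."""
    def __new__(cls, value, origin=None):
        obj = super().__new__(cls, value)
        # Characters without a known origin are marked with -1
        obj.origin = list(origin) if origin is not None else [-1] * len(value)
        return obj

    def __add__(self, other):
        # OStr + str or OStr + OStr: concatenate the origin lists as well
        other_origin = getattr(other, 'origin', [-1] * len(other))
        return OStr(str.__add__(self, other), self.origin + other_origin)

    def __radd__(self, other):
        # str + OStr: the plain str on the left contributes unknown origins
        return OStr(str.__add__(other, self), [-1] * len(other) + self.origin)

def tracked_join(sep, parts):
    """Hypothetical join() replacement that concatenates with +, so origins survive."""
    result = OStr('')
    for i, part in enumerate(parts):
        if i > 0:
            result = result + sep
        result = result + part
    return result

hello = OStr('hello', origin=range(100, 105))
world = OStr('world', origin=range(200, 205))

print(hasattr(''.join([hello, ' ', world]), 'origin'))   # False
print(tracked_join(' ', [hello, world]).origin)
# [100, 101, 102, 103, 104, -1, 200, 201, 202, 203, 204]
```

The price is speed: the pure-Python loop is much slower than the C implementation of str.join(), which is one reason why such wrappers are typically limited to the functions that matter for a given analysis.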

Implicit Information Flow

Even if one could taint all data in a program, there would still be ways to break information flow, notably by turning explicit flow into implicit flow, or data flow into control flow. Here is an example:

def strip_all_info_again(s):
    t = ""
    for c in s:
        if c == 'a':
            t += 'a'
        elif c == 'b':
            t += 'b'
        elif c == 'c':
            t += 'c'
        # ... and so on, one branch for each possible character
    return t

With such a function, there is no explicit data flow between the characters in s and the characters in t; yet, the strings would be identical. This problem frequently occurs in programs that process and manipulate external input.
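The following sketch turns this into a runnable demonstration, assuming a hypothetical minimal TStr class in place of tstr. The hypothetical launder() function compares each input character against a fixed alphabet and appends the matching constant, so only control flow connects input and output: the copy equals the input, but carries no taint. Characters outside string.printable are simply dropped in this sketch.

```python
import string

class TStr(str):
    """Hypothetical minimal tainted string (a stand-in for tstr)."""
    def __new__(cls, value, taint=None):
        obj = super().__new__(cls, value)
        obj.taint = taint
        return obj

def launder(s, alphabet=string.printable):
    """Hypothetical helper: copy s using control flow only."""
    t = ""
    for c in s:
        for candidate in alphabet:
            if c == candidate:    # the comparison is the only link to c
                t += candidate    # candidate is a constant from alphabet, not from s
                break
    return t

secret = TStr("Secret", taint='SECRET')
laundered = launder(secret)
print(laundered == secret, hasattr(laundered, 'taint'))   # True False
```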

Enforcing Tainting

Conversions and implicit information flow are just two of several ways in which taint and origin information can get lost. To address the problem, the best solution is to always assume the worst about untainted strings:

  • When it comes to trust, an untainted string should be treated as possibly untrusted, and hence not relied upon unless sanitized.

  • When it comes to privacy, an untainted string should be treated as possibly secret, and hence not leaked out.

As a consequence, your program should always have two kinds of taints: one for explicitly trusted (or secret) and one for explicitly untrusted (or non-secret) data. If a taint gets lost along the way, you may have to restore it from its sources, not unlike the string methods discussed above. The benefit is a trusted application, in which each and every information flow can be checked at runtime, with violations quickly discovered through automated tests.
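As a minimal sketch of this worst-case policy, the hypothetical checks below accept a string only if it explicitly carries the expected taint; anything without a taint attribute, such as the result of a conversion, is rejected. The class TStr, the functions may_execute() and may_output(), and the labels 'TRUSTED', 'PUBLIC', and 'SECRET' are assumptions made for this example, not the chapter's API.

```python
class TStr(str):
    """Hypothetical minimal tainted string (a stand-in for tstr)."""
    def __new__(cls, value, taint=None):
        obj = super().__new__(cls, value)
        obj.taint = taint
        return obj

# Hypothetical taint labels for this example
TRUSTED = 'TRUSTED'
PUBLIC = 'PUBLIC'

def may_execute(s):
    """Trust: only strings explicitly tainted as trusted may reach a sensitive sink.
    A string without a taint (e.g., after a conversion) counts as untrusted."""
    return getattr(s, 'taint', None) == TRUSTED

def may_output(s):
    """Privacy: only strings explicitly marked public may leave the system.
    A string without a taint counts as possibly secret."""
    return getattr(s, 'taint', None) == PUBLIC

query = TStr("select year from inventory", taint=TRUSTED)
print(may_execute(query))                          # True
print(may_execute(str(query)))                     # False: str() dropped the taint
print(may_output(TStr("Secret", taint='SECRET')))  # False
```

Note the asymmetry: both checks default to rejecting, so a lost taint can only make the program more restrictive, never more permissive.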

Lessons Learned

  • String-based and character-based taints allow us to dynamically track the information flow from input to the internals of a system and back to the output.

  • Checking taints allows us to discover untrusted inputs and information leakage at runtime.

  • Data conversions and implicit data flow may strip taint information; the resulting untainted strings should be treated as having the worst possible taint.

  • Taints can be used in conjunction with fuzzing to provide a more robust indication of incorrect behavior than simply relying on program crashes.

Next Steps

An even better alternative to our taint-directed fuzzing is to make use of symbolic techniques that take the semantics of the program under test into account. The chapter on flow fuzzing introduces these symbolic techniques for the purpose of exploring information flows; the subsequent chapter on symbolic fuzzing then shows how to make full-fledged use of symbolic execution for covering code. Similarly, search-based fuzzing can often provide a cheaper exploration strategy.

Background

Taint analysis on Python using a library approach, as we implemented it in this chapter, was discussed by Conti et al. [Conti et al., 2012].

Exercises

Exercise 1: Tainted Numbers

Introduce a class tint (for tainted integer) that, like tstr, has a taint attribute that gets passed on from tint to tint.

Part 1: Creation

Implement the tint class such that taints are set:

x = tint(42, taint='SECRET')
assert x.taint == 'SECRET'

Part 2: Arithmetic expressions

Ensure that taints get passed along arithmetic expressions; support addition, subtraction, multiplication, and division operators.

y = x + 1
assert y.taint == 'SECRET'

Part 3: Passing taints from integers to strings

Converting a tainted integer into a string (using repr()) should yield a tainted string:

s = repr(x)
assert s.taint == 'SECRET'

Part 4: Passing taints from strings to integers

Converting a tainted object (with a taint attribute) to an integer should pass that taint:

password = tstr('1234', taint='NOT_EXACTLY_SECRET')
x = tint(password)
assert x == 1234
assert x.taint == 'NOT_EXACTLY_SECRET'

Exercise 2: Information Flow Testing

Generate tests that ensure a maximum of information flow, propagating specific taints as much as possible. Implement an appropriate fitness function for search-based testing and let the search-based fuzzer search for solutions.

The content of this project is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. The source code that is part of the content, as well as the source code used to format and display that content, is licensed under the MIT License. Last change: 2019-05-21 19:58:01+02:00