-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpg.py
More file actions
116 lines (99 loc) · 3.63 KB
/
pg.py
File metadata and controls
116 lines (99 loc) · 3.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import psycopg2
import psycopg2.extras
import os
import re
# Connection string: taken from the DATABASE_URL environment variable, falling
# back to a local database reached as the passwordless `postgres` user.
DATABASE_URL = os.environ.get('DATABASE_URL', 'postgresql://postgres:@localhost/python-rest-api')
# One module-level connection, opened at import time; the execute/query
# helpers in this module all create their cursors from it.
conn = psycopg2.connect(DATABASE_URL)
# Autocommit mode: statements take effect immediately, so no helper in this
# module calls commit() or rollback().
conn.autocommit = True
def execute(*args):
    """Run a statement on the shared connection and return its cursor.

    All positional arguments are forwarded unchanged to ``cursor.execute``.
    The cursor is returned still open so the caller can read from it
    (``create`` fetches the ``RETURNING id`` row this way).
    """
    cursor = conn.cursor()
    cursor.execute(*args)
    return cursor
def query_tuple(*args):
    """Run a query and return every result row as fetched by a plain cursor.

    All positional arguments are forwarded unchanged to ``cursor.execute``.

    Returns:
        The list produced by ``cursor.fetchall()``.
    """
    # Every row is fetched before the block exits, so the cursor can be closed
    # deterministically instead of lingering until it is garbage collected.
    with conn.cursor() as cur:
        cur.execute(*args)
        return cur.fetchall()
def query(*args):
    """Run a query and return every result row as a plain ``dict``.

    All positional arguments are forwarded unchanged to ``cursor.execute``.

    Returns:
        A list with one ``dict`` per row, built from the ``DictCursor`` rows.
    """
    # Rows are copied into plain dicts before the block exits, so nothing
    # returned depends on the cursor and it can be closed right away.
    with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
        cur.execute(*args)
        return [dict(row) for row in cur.fetchall()]
def query_one(*args):
    """Return the first row of a query as a dict, or ``None`` if there are no rows.

    Arguments are passed straight through to ``query``, which fetches the
    full result set; only the first row is kept.
    """
    results = query(*args)
    if not results:
        return None
    return results[0]
def is_valid_column(column):
    """Check that ``column`` consists solely of ASCII letters, digits and underscores.

    Returns:
        The ``re.Match`` object (truthy) when the whole string matches,
        otherwise ``None``.
    """
    # Raw pattern avoids invalid-escape warnings; fullmatch anchors at both
    # ends, so an empty string, a trailing newline or any punctuation fails.
    return re.fullmatch(r'[a-zA-Z0-9_]+', column)
def assert_valid_columns(columns):
    """Raise if any entry of ``columns`` is rejected by ``is_valid_column``.

    Raises:
        Exception: listing every rejected name.
    """
    # Guard against SQL injection: these names end up in the SQL text itself.
    invalid_columns = []
    for name in columns:
        if not is_valid_column(name):
            invalid_columns.append(name)
    if len(invalid_columns) > 0:
        raise Exception(f'Invalid column names: {invalid_columns}')
def where_sql(filter):
    """Build a parameterized WHERE clause from a filter mapping.

    ``filter`` maps a column name to a dict with an ``'op'`` and a ``'value'``
    entry. Supported ops are ``'contains'`` (``like`` with the value wrapped
    in ``%``), ``'lt'`` (``<``) and ``'gt'`` (``>``); any other op is treated
    as equality. The per-column conditions are joined with ``and``.

    Returns:
        A ``(sql, values)`` tuple, where ``values`` are the parameters to bind
        to the ``%s`` placeholders in ``sql``. An empty or ``None`` filter
        yields ``('', ())``.

    Raises:
        Exception: from ``assert_valid_columns`` when a column name is not a
            plain identifier.
    """
    if not filter:
        return ('', ())
    columns = list(filter)
    # Only the values are bound as parameters; column names are pasted into
    # the SQL text, so they get the same validation that order_sql, create
    # and update apply before interpolating names.
    assert_valid_columns(columns)

    clauses = []
    values = []
    for column in columns:
        op = filter[column]['op']
        value = filter[column]['value']
        if op == 'contains':
            clauses.append(f'{column} like %s')
            # Substring match: wrap the search value in LIKE wildcards.
            values.append(f'%{value}%')
        elif op == 'lt':
            clauses.append(f'{column} < %s')
            values.append(value)
        elif op == 'gt':
            clauses.append(f'{column} > %s')
            values.append(value)
        else:
            # Unrecognized ops fall back to equality.
            clauses.append(f'{column} = %s')
            values.append(value)

    sql = 'WHERE ' + ' and '.join(clauses)
    return (sql, tuple(values))
def order_sql(sort):
    """Translate a comma-separated sort spec into an ORDER BY clause.

    Each item is a column name; a leading ``-`` selects descending order,
    anything else ascending. Names are checked with ``assert_valid_columns``
    before being placed in the SQL text.

    Returns:
        The ``ORDER BY ...`` string, or ``''`` when ``sort`` is empty or None.
    """
    if not sort:
        return ''

    names = []
    directions = []
    for item in sort.split(','):
        descending = item.startswith('-')
        names.append(item[1:] if descending else item)
        directions.append('DESC' if descending else 'ASC')

    assert_valid_columns(names)
    terms = [f'{name} {direction}' for name, direction in zip(names, directions)]
    return 'ORDER BY ' + ', '.join(terms)
#############################################################
#
# Database Interface
#
#############################################################
# JSON-schema fragment describing an id: an integer >= 1, flagged as
# non-writable through the custom 'x-meta' key.
# NOTE(review): not referenced in this file — presumably imported by the
# modules that declare per-table schemas; confirm against callers.
id_json_schema = {'type': 'integer', 'minimum': 1, 'x-meta': {'writable': False}}
def count(table_name, filter=None):
    """Return the number of rows in ``table_name`` that match ``filter``.

    ``filter`` is turned into a WHERE clause by ``where_sql``; with no filter
    every row is counted. ``table_name`` is placed in the SQL text as given.
    """
    where_clause, params = where_sql(filter)
    row = query_one(f'select count(*) from {table_name} {where_clause}', params)
    return row['count']
def find(table_name, limit=100, offset=0, sort=None, filter=None):
    """Fetch a page of rows from ``table_name`` as a list of dicts.

    ``filter`` goes through ``where_sql`` and ``sort`` through ``order_sql``;
    ``limit`` and ``offset`` are bound as query parameters. The generated SQL
    and its parameter values are printed before the query runs.
    """
    where_clauses, where_values = where_sql(filter)
    values = (*where_values, limit, offset)
    order_clause = order_sql(sort)
    sql = f'select * from {table_name} {where_clauses} {order_clause} LIMIT %s OFFSET %s'
    print(f'find sql={sql} values={values}')
    return query(sql, values)
def find_one(table_name, id):
    """Return the row of ``table_name`` whose ``id`` column equals ``id``.

    The result is a dict, or ``None`` when no row matches.
    """
    sql = f'select * from {table_name} where id = %s'
    return query_one(sql, [id])
def create(table_name, doc):
    """Insert ``doc`` as a new row of ``table_name`` and return the new id.

    The keys of ``doc`` are used as column names (validated with
    ``assert_valid_columns``) and its values are bound as query parameters.
    An empty ``doc`` inserts a row made entirely of column defaults.
    The generated SQL is printed before it runs.

    Returns:
        The first column of the ``RETURNING id`` row.

    Raises:
        Exception: from ``assert_valid_columns`` when a key is not a plain
            identifier.
    """
    columns = list(doc.keys())
    assert_valid_columns(columns)
    values = [doc[k] for k in columns]
    if columns:
        placeholders = ', '.join(['%s'] * len(columns))
        sql = f'INSERT INTO {table_name} ({", ".join(columns)}) VALUES ({placeholders}) RETURNING id'
    else:
        # "INSERT INTO t () VALUES ()" is a syntax error in PostgreSQL;
        # DEFAULT VALUES is the valid way to insert a row with no columns given.
        sql = f'INSERT INTO {table_name} DEFAULT VALUES RETURNING id'
    print(sql)
    cur = execute(sql, values)
    return cur.fetchone()[0]
def update(table_name, id, doc):
    """Set the columns named by ``doc``'s keys on the row with the given ``id``.

    Keys are validated with ``assert_valid_columns``; values and the id are
    bound as query parameters. The generated SQL is printed before it runs.

    Returns:
        The cursor returned by ``execute``.
    """
    columns = list(doc.keys())
    assert_valid_columns(columns)
    assignments = ', '.join(f'{name} = %s' for name in columns)
    # Parameter order: one value per SET assignment, then the id for the WHERE.
    params = [doc[name] for name in columns]
    params.append(id)
    sql = f'UPDATE {table_name} SET {assignments} where id = %s'
    print(sql)
    return execute(sql, params)
def delete(table_name, id):
    """Delete the row of ``table_name`` whose ``id`` column equals ``id``.

    Returns:
        The cursor returned by ``execute``.
    """
    sql = f'DELETE from {table_name} where id = %s'
    return execute(sql, [id])