flask项目单元测试实践
Test early. Test often. Test automatically. Tests that run with every build are much more effective than test plans that sit on a shelf.
最近开始做公司内部项目,CRM系统(客户关系管理), 用的flask+python3.5.2。我以前闲着没事逛逛github的时候发现了这个cookiecutter-flask,就是用来生成一个项目模板的东西,直接帮你生成项目总体框架还有README文件,还是比较方便的,直接填写逻辑即可。以前并没有使用flask的经验,这次也是边摸索一边使用(好在没碰到坑),主要想记录下关于单元测试的东西。crm主要是crud操作,这次我比较重视测试代码编写,web项目单元测试需要处理数据库交互,模拟请求,模拟登录,表单提交等操作,如何编写易于构建和执行的单元测试也是需要注意的地方。新项目统一用flask+python3.5.2(python3库的支持比我想象中快,目前使用的依赖中都支持python3),前端使用vue,前后分离,后端使用flask-restful写接口。贴出来一些代码,如果写得不合适的正好可以给我指正下:)
单元测试的必要性
之前曾经写过一篇讲单元测试的,正好最近也在实践和摸索。我似乎有种洁癖,就是我会严格遵守流程性的东西,比如测试,注释和文档等。目前就职的公司在我接手项目的时候是没有一行单元测试的,我挺诧异的。我大概也能估计到目前国内的python项目团队很多是不太规范的(比如没有单元测试、代码不符合规范、没有持续继承等)。当然流程不够规范可能不会有什么大问题,但是绝对会给代码维护造成困难,我是踩了坑的,所以要保持谨慎。虽然这次工期比较紧,半个月内搞出来一个CRM系统,但是目前这一周多的进度还是严格遵守了规范并完善了测试,并且进展还是挺顺利的,感觉单元测试确实能减少bug出现率。至于会不会浪费开发和维护时间,还需要自己权衡。至少前端不用频繁抱怨接口怎么又500啦。
写单元测试会降低生产力吗?
可能有些人不写单元测试有个重要的理由就是会延误工期。当然你要是问我究竟会不会降低生产力,我只能悲剧滴说我不知道。公司永远不会给你时间让你在两种开发方式上实验,所以很难衡量(算上开发、修改bug和维护等的时间)。不过从我这次做项目的经验上来看,至少不会耗费太多时间,而单元测试的好处是显而易见的:减少bug,保证重构不会破坏代码,简洁的设计等。目前几乎国内所有python工程师都是自学成才,学习能力是有保证的,不过职业素养就难说了。写代码有时候并不难,但是维护成本却很高。我觉得写单元测试和你写注释一样,什么时候写,写什么都需要有良好的判断力,我们需要的是质量,而不是数量。我刚学python的时候就是习惯写个函数,print下看看结果,觉得没错就认为可以工作了,现在的习惯就是把期望结果写在单元测试里头,用assert判断是否符合预期。当然,最重要的还是尽量编写清晰易懂的代码。
flask单元测试
由于我直接偷懒使用了[cookiecutter-flask]生成框架,自带了一个tests文件夹,我就直接照葫芦画瓢就好。首先在tests文件夹下有一个py.test使用的conftest.py文件(推荐你使用pytest做测试,相当便捷)
# -*- coding: utf-8 -*-
"""Defines fixtures available to all tests.
http://doc.pytest.org/en/latest/fixture.html?highlight=fixture
"""
import pytest
from webtest import TestApp
from crm_backend.app import create_app
from crm_backend.database import db as _db
from crm_backend.settings import TestConfig
from .factories import UserFactory
@pytest.yield_fixture(scope='function')
def app():
"""An application for the tests."""
_app = create_app(TestConfig)
ctx = _app.test_request_context()
ctx.push()
yield _app
ctx.pop()
@pytest.yield_fixture
def client(app):
"""A Flask test client. An instance of :class:`flask.testing.TestClient`
by default.
"""
with app.test_client() as client:
yield client
@pytest.fixture(scope='function')
def testapp(app):
"""A Webtest app."""
return TestApp(app)
@pytest.yield_fixture(scope='function')
def db(app):
"""A database for the tests."""
_db.app = app
with app.app_context():
_db.create_all()
yield _db
# Explicitly close DB connection
_db.session.close()
_db.drop_all()
@pytest.fixture
def user(db):
"""A user for the tests."""
user = UserFactory(password='myprecious')
db.session.commit()
return user
对于普通的python函数或者类,可以直接使用简单的test函数,由于编写的是web项目,麻烦的地方就在于和数据库以及后端请求的交互。在cookiecutter中使用了pytest的fixture特性来处理和数据库的交互问题,使用了webtest库来处理请求问题。分别来看看如何测试Model和View层,我这里使用了flask restful,所以改成了api层。
使用py.test测试Model层
下边是cookiecutter-flask自动生成的关于user的Model单元测试。这里有一点需要注意,测试类TestUser使用了fixture db,这个fixture在conftest.py中定义的,使用的测试配置 SQLALCHEMY_DATABASE_URI = ‘sqlite:///:memory:’,所有操作都是在内存中进行,db使用这个模拟的sqllite内存数据库。其他貌似也没啥好说的了,都是基本的crud操作,照着写测试就行,没啥好说的:
# -*- coding: utf-8 -*-
"""Model unit tests."""
import datetime as dt
import pytest
from crm_backend.user.models import Role, User
from .factories import UserFactory
@pytest.mark.usefixtures('db')
class TestUser:
"""User tests."""
def test_get_by_id(self):
"""Get user by ID."""
user = User('foo', 'foo@bar.com')
user.save()
retrieved = User.get_by_id(user.id)
assert retrieved == user
def test_created_at_defaults_to_datetime(self):
"""Test creation date."""
user = User(username='foo', email='foo@bar.com')
user.save()
assert bool(user.created_at)
assert isinstance(user.created_at, dt.datetime)
测试flask接口
这里使用的是WebTest这个库进行测试,而没有使用flask自带的test_client,WebTest还是比较方便的,常见的也就是get、post、put方法和请求数据的提交,也比较简单,代码见示例:
# -*- coding: utf-8 -*-
"""
flask flask_restful api的单元测试
"""
import pytest
from crm_backend.extensions import api
from crm_backend.advertiser.api import (
AdvertiserListApi, AdvertiserApi,
BusinessLeadListApi,
)
from crm_backend.advertiser.models import (
Advertiser,
)
from crm_backend.employee.models import (
Employee,
)
@pytest.mark.usefixtures('db')
class TestAdvertiserListApi:
def test_get(self, testapp):
self.test_post(testapp) # 先创建一个advertiser
url = api.url_for(AdvertiserListApi)
res = testapp.get(
url,
{
# 'fields': "id,name",
# 'filter': """[{"field":"status","op":"eq","q":1}]""",
'limit': 1000,
# 'order': "",
'page': 1
},
expect_errors=True
)
assert len(res.json['data']['items']) == 1
def test_post(self, testapp):
url = api.url_for(AdvertiserListApi)
bd = Employee.create(
name='e1', email='e1@bar.com', password='foobarbaz123',
team=Employee.TeamEnum.__dict__['CN-Beijing1'],
is_leader=True, # set leader
role=Employee.RoleEnum.BD
)
res = testapp.post_json(
url,
{
'name': 'advertiser_wang',
'contact_name': 'xiaoliu',
'phone': '18818881888',
'email': 'tes@qq.com',
},
expect_errors=True
)
a = Advertiser.get_by_id(1)
assert res.json['data']['id'] == a.id
assert a.name == 'advertiser_wang'
assert a.bd == bd
assert a.is_client
return a
@pytest.mark.usefixtures('db')
class TestAdvertiserApi:
def test_get(self, testapp):
a = Advertiser.create(name='adervertiser_wang')
url = api.url_for(AdvertiserApi, advertiser_id=a.id)
res = testapp.get(url)
assert res.json['id'] == str(a.id)
def test_put(self, testapp):
a = TestAdvertiserListApi().test_post(testapp) # 先创建个用户再更新
url = api.url_for(AdvertiserApi, advertiser_id=a.id)
res = testapp.put_json(
url,
{
'name': 'new_advertiser_wang',
'contact_name': 'xiaoliu',
'phone': '18818881888',
'email': 'tes@qq.com',
},
expect_errors=True
)
# 测试名称已经更新
assert Advertiser.get_by_id(a.id).name == 'new_advertiser_wang'
使用marshmallow.Schema 序列化返回数据
marshmallow is an ORM/ODM/framework-agnostic library for converting complex datatypes, such as objects, to and from native Python datatypes.
阅读flask restful文档的时候发现提到了这么个marshmallow东西我就直接在项目中使用了。
在做后台接口时,一般会碰到两个问题,一个就是参数(表单)验证,还有一个就是数据返回。参数或者表单验证都可以用wtforms完成,或者可以尝试flask eve作者写的看门狗 Cerberus,这个Cerberus是专门用来搞字段校验的,不涉及表单。数据返回可能不同项目有不同的做法。
marshmallow的作用就是用来序列化自定义的一些Python类实例。比如我们从数据库用sqlalchemy查到一个对象列表以后,需要按照指定格式返回前端需要的数据和类型,之前的做法都是自己用函数转成个dict,现在这种模式化的东西可以直接使用marshmalow里的Schema来做,而且非常灵活,需要返回不同格式或者类型的数据直接可以自定义一个schema解决。给个官方文档的例子:
from datetime import date
from marshmallow import Schema, fields, pprint
class ArtistSchema(Schema):
name = fields.Str()
class AlbumSchema(Schema):
title = fields.Str()
release_date = fields.Date()
artist = fields.Nested(ArtistSchema())
bowie = dict(name='David Bowie')
album = dict(artist=bowie, title='Hunky Dory', release_date=date(1971, 12, 17))
schema = AlbumSchema()
result = schema.dump(album)
pprint(result.data, indent=2)
# { 'artist': {'name': 'David Bowie'},
# 'release_date': '1971-12-17',
# 'title': 'Hunky Dory'}
实际上我感觉和最近比较火的graphql有点像,通过定义一系列查询模式直接返回数据。这样我们就不用自己转成dict了,不直观也不够通用。使用这种Schema以后你就可以写个统一的查询函数了,需要不同的数据格式只要把Schema类作为参数传给函数就好,我甚至尝试用一个统一的分页查询函数解决了所有Model的分页查询和过滤问题。
sqlalchemy使用的一些注意事项
数据库一直是我的弱项,还好这次没碰到啥坑问题。
- 涉及不同时区的时候统一存储utc时间,然后根据不同时区转化。可以看下flask_babel模块,用来处理国际化问题的。
- 涉及到金钱相关的数据最好用decimal处理,不要用float造成精度损失。
- 最好直接使用date和datetime类型,不要存储字符串日期。
- 建字段的时候考虑下db.Column的参数default, index, nullable, primary_key, unique哪些约束会用得上。
- 使用不使用外键?我看阿里的java规范中不允许使用外键与级联,外键和级联更新只适合单机低并发,不适合分布式、高并发集群。我搜了下网上的信息,发现争论还是很多的。我觉得做内部系统像是erp或者crm等应用还是用外键比较好。
- 创建数据库使用utf8,
CREATE DATABASE IF NOT EXISTS my_db default charset utf8 COLLATE utf8_general_ci;
- 连接url使用
mysql://root:root@127.0.0.1:3306/my_db?charset=utf8
。不用担心乱码问题了
增强flask_sqlalchemy自带的Model类
cookiecutter-flask生成的框架里边自带了一个CRUDMixin类,用来给Model增加常用的增删改查,我稍微加了几个函数用来解决一些通用的查询。比如我的query_paginate_and_dump一个函数解决了几乎大部分的查询问题。(借鉴了他人的一些代码)
# -*- coding: utf-8 -*-
"""Database module, including the SQLAlchemy database object and DB-related
utilities."""
import datetime as dt
from pprint import pformat
from marshmallow import Schema
from sqlalchemy import desc, or_
from sqlalchemy.sql.sqltypes import Date, DateTime
from sqlalchemy.orm import relationship
from werkzeug import cached_property
from .compat import basestring
from .extensions import db
from .utils import date_str_to_obj, datetime_str_to_obj
# Alias common SQLAlchemy names
Column = db.Column
relationship = relationship
OPERATOR_FUNC_DICT = {
'=': (lambda cls, k, v: getattr(cls, k) == v),
'==': (lambda cls, k, v: getattr(cls, k) == v),
'eq': (lambda cls, k, v: getattr(cls, k) == v),
'!=': (lambda cls, k, v: getattr(cls, k) != v),
'ne': (lambda cls, k, v: getattr(cls, k) != v),
'neq': (lambda cls, k, v: getattr(cls, k) != v),
'>': (lambda cls, k, v: getattr(cls, k) > v),
'gt': (lambda cls, k, v: getattr(cls, k) > v),
'>=': (lambda cls, k, v: getattr(cls, k) >= v),
'gte': (lambda cls, k, v: getattr(cls, k) >= v),
'<': (lambda cls, k, v: getattr(cls, k) < v),
'lt': (lambda cls, k, v: getattr(cls, k) < v),
'<=': (lambda cls, k, v: getattr(cls, k) <= v),
'lte': (lambda cls, k, v: getattr(cls, k) <= v),
'or': (lambda cls, k, v: or_(getattr(cls, k) == value for value in v)),
'in': (lambda cls, k, v: getattr(cls, k).in_(v)),
'nin': (lambda cls, k, v: ~getattr(cls, k).in_(v)),
'like': (lambda cls, k, v: getattr(cls, k).like('%{}%'.format(v))),
'nlike': (lambda cls, k, v: ~getattr(cls, k).like(v)),
'+': (lambda cls, k, v: getattr(cls, k) + v),
'incr': (lambda cls, k, v: getattr(cls, k) + v),
'-': (lambda cls, k, v: getattr(cls, k) - v),
'decr': (lambda cls, k, v: getattr(cls, k) - v),
}
def parse_operator(cls, filter_name_dict):
""" 用来返回sqlalchemy query对象filter使用的表达式
Args:
filter_name_dict (dict): 过滤条件dict
{
'last_name': {'eq': 'wang'}, # 如果是dic使用key作为操作符
'age': {'>': 12}
}
Returns:
binary_expression_list (lambda list)
"""
def _change_type(cls, field, value):
""" 有些表字段比如DateTime类型比较的时候需要转换类型,
前端传过来的都是字符串,Date等类型没法直接相比较,需要转成Date类型
Args:
cls (class): Model class
field (str): Model class field
value (str): value need to compare
"""
field_type = getattr(cls, field).type
if isinstance(field_type, Date):
return date_str_to_obj(value)
elif isinstance(field_type, DateTime):
return datetime_str_to_obj(value)
else:
return value
binary_expression_list = []
for field, op_dict in filter_name_dict.items():
for op, op_val in op_dict.items():
op_val = _change_type(cls, field, op_val)
if op in OPERATOR_FUNC_DICT:
binary_expression_list.append(
OPERATOR_FUNC_DICT[op](cls, field, op_val)
)
return binary_expression_list
class CRUDMixin(object):
"""Mixin that adds convenience methods for
CRUD (create, read, update, delete) operations."""
@classmethod
def create(cls, **kwargs):
"""Create a new record and save it the database."""
instance = cls(**kwargs)
return instance.save()
@classmethod
def create_from_dict(cls, d):
"""Create a new record and save it the database."""
assert isinstance(d, dict)
instance = cls(**d)
return instance.save()
def update(self, commit=True, **kwargs):
"""Update specific fields of a record."""
for attr, value in kwargs.items():
setattr(self, attr, value)
return commit and self.save() or self
def save(self, commit=True):
"""Save the record."""
db.session.add(self)
if commit:
db.session.commit()
return self
def delete(self, commit=True):
"""Remove the record from the database."""
db.session.delete(self)
return commit and db.session.commit()
def to_dict(self, fields_list=None):
"""
Args:
fields (str list): 指定返回的字段
"""
column_list = fields_list or [
column.name for column in self.__table__.columns
]
return {
column_name: getattr(self, column_name)
for column_name in column_list
}
@classmethod
def create_or_update(cls, query_dict, update_dict=None):
instance = db.session.query(cls).filter_by(**query_dict).first()
if instance: # update
if update_dict is not None:
return instance.update(**update_dict)
else:
return instance
else: # create new instance
query_dict.update(update_dict or {})
return cls.create(**query_dict)
@classmethod
def query_paginate(cls, page=1, limit=20, fields=None, order_by_list=None,
filter_name_dict=None):
""" 通用的分页查询函数
Args:
page (int): 页数
limit (int): 每页个数
order_by_list (tuple list): 用来指定排序的字段,可以传多个
[ ('id', 1), ('name', -1) ],1表示正序,-1表示逆序
or
[ ('id', 'asc'), ('name', 'desc') ],1表示正序,-1表示逆序
filter_name_dict (dict): 过滤条件,使用字典表示,使用字段名作为key,value
是{'operator': to_compare_value}, e.g.:
{
'last_name': {'eq': 'wang'}, # 如果是dic使用key作为操作符
'age': {'>': 12}
}
Returns:
if fields is not None:
(keytuple_list, total_cnt) (tuple)
else:
(instance_list, total_cnt) (tuple)
前段查询参数规范:
request.args 示例:
ImmutableMultiDict([('limit', '10'), ('page', '1'), ('filter', '[{"field":"name","op":"eq","q":"g"},{
"field":"id","op":"gt","q":"5"
}]')])
page: 页码
limit: 每页限制
order: 顺序,取值"asc" "desc". """'name', 'asc', 'model','desc'"""
fields: 需要返回的字段
filter: 过滤条件:
{
field: 需要过滤的字段
op: 过滤操作符,支持"eq","neq","gt","gte","lt","lte","in","nin","like"
q: 过滤值
}
"""
fields = (
[getattr(cls, column) for column in fields] if fields is not None
else None
)
if fields:
query = db.session.query(*fields)
else:
query = db.session.query(cls)
if order_by_list:
for (field, order) in order_by_list:
query = (
query.order_by(getattr(cls, field)) if order == 1 else
query.order_by(desc(getattr(cls, field)))
)
if filter_name_dict:
p = parse_operator(cls, filter_name_dict)
query = query.filter(*p)
total_cnt = query.count()
start = (page-1) * limit
query = query.offset(start).limit(limit)
instance_or_keytuple_list = query.all()
return instance_or_keytuple_list, total_cnt
@classmethod
def dump_schema(cls, items, fields, schema_class):
""" 用来序列化从数据库查询出来的对象
Args:
items (instance list): obj list query from mysql
fields (str list): fields need to dump
schema_class (marshmallow.Schema): marshmallow.Schema class
Returns:
items, err
"""
fields = (
fields if fields else list(schema_class._declared_fields.keys())
)
schema = schema_class(many=True, only=fields)
items, err = schema.dump(items)
return items, err
@classmethod
def query_paginate_and_dump_schema(cls, page=1, limit=20, fields=None,
order_by_list=None,
filter_name_dict=None,
schema_class=None):
""" 分页查询并且返回dump后的对象,可以解决大部分查询问题 """
assert schema_class
items, total_cnt = cls.query_paginate(
page, limit, fields, order_by_list, filter_name_dict
)
items, err = cls.dump_schema(items, fields, schema_class)
return items, total_cnt
def __repr__(self):
return pformat(self.to_dict())
@cached_property
def column_name_set(self):
return set([column.name for column in self.__table__.columns])
@classmethod
def get_common_fields(cls, fields=None):
""" 防止传过来的fields含有该Model没有的字段 """
if not fields:
return []
table_fields_set = set(
[column.name for column in cls.__table__.columns]
)
return list(table_fields_set & set(fields))
class Model(CRUDMixin, db.Model):
"""Base model class that includes CRUD convenience methods."""
__abstract__ = True
暂时就更新这么多吧,flask用得不多,可能还有一些坑要踩,在此记录一下。
摸索与总结:
- 不要因为进度紧而违背原则,你可能永远无法衡量违背原则浪费的时间,一时偷懒会付出更多维护代价,得不偿失。
- 使用pytest的fixture特性处理和数据库的交互。
- 使用webtest进行view或者api层的接口调用测试。
- 单元测试代码不要复杂,以免在测试中引入bug。如果觉得麻烦,应该在你觉得必要时候写单元测试。
- 单元测试一定要易于构造和执行(这一点最最重要)。单元测试只在内存中进行,在这里我的单元测试配置使用的是sqllite的memory配置, SQLALCHEMY_DATABASE_URI = ‘sqlite:///:memory:’, 甚至都没有使用过mysql。好处就是到哪里都可以运行测试,不要给测试引入复杂性,并且提升了测试速度。和前端对接的时候测试数据库使用另一套配置,不要和单元测试混用。只要遵循需要时创建,测试完删除就完全可以做到只在内存中执行单元测试。如果需要生成测试数据,在TestClass里写一个函数专门用来生成测试数据,测试时候调用,而不要使用真实的数据库给构造测试带来麻烦,否则可能会引入很多垃圾测试数据,并且构建和运行单元测试变得困难.
- 可以使用fswatch等命令工具监控文件变动,每次更改都会触发重跑相关测试,保证不会改出问题。我现在写代码就是开了多窗口,每次改动我都会看右边测试是否有红色的Failed的结果。方便我快速重构,效果还是挺赞的。相当于本地的持续集成。
- 新项目可以直接使用cookiecutter-flask生成框架,你只需要快速填写业务逻辑,无需过度关心代码结构组织。
- 处理返回数据可以使用marshmallow.Schema解决,更加灵活通用、直观。甚至不用写返回值文档,直接看Schema结构就行。除非你对速度有变态要求,增加一层抽象会损失部分性能。如果序列化有性能瓶颈,可以尝试下ujson,我看到很多benchmark结果ujson都是最快的。
- 写测试不是走走流程和形式,而是让你改变”写一点代码print看结果就觉得OK了”的不良编码习惯。这一点对于大型项目至关重要.步子不要迈大了,写一个测一个,防止组合起来以后难以定位错误。
- 单元测试里少用甚至不要用print,你临时看看结果可以用,最好多使用assert语句,把你期望的结果用assert语句表示出来,print多了会扰乱输出,当测试函数多了的时候会给验证结果代码麻烦。
- 出现bug修复后把单元测试用例加到原有用例中,防止再次出错。(输入组合,边界条件和异常)
- 一开始就要保证项目质量,防止积累技术债务给维护带来负担(单元测试让我可以放心大胆地快速重构)
- 警惕难以测试的代码,难以测试的代码往往需要重构。
总得来说用flask开发还是一件挺愉快的事情,丰富的插件支持,cookiecutter-flask也是个很方便的项目生成框架,pandas可以很方便滴处理报表,做东西还算比较快。希望本篇涉及到的一些东西可以让你的开发更高效,如果对代码有什么建议也欢迎沟通😄.另外有个惊喜就是python3比我想象中成熟,Python 3 Readiness列出PyPI上最常用的已经支持了python3的库,如果新项目没有历史包袱,完全可以尝试python3.5,下边是我们项目使用的依赖,都支持python3了(python3.5使用没问题),暂时没有碰到任何问题(试水成功以后新项目都开始用py3)。同时也可以尝试下python3.5的类型注解、asyncio等新玩意。
Flask-Admin==1.4.2
Flask-Assets==0.12
Flask-Babel==0.11.1
Flask-Bcrypt==0.7.1
Flask-DebugToolbar==0.10.0
Flask-Login==0.3.2
Flask-Mail==0.9.1
Flask-Migrate==2.0.0
Flask-RESTful==0.3.5
Flask-Redis==0.3.0
Flask-SQLAlchemy==2.1
Flask-WTF==0.12
Flask==0.11.1
Jinja2==2.8
MarkupSafe==0.23
Pillow==3.4.2
SQLAlchemy==1.0.12
WTForms==2.1
WebTest==2.0.20
Werkzeug==0.11.4
boto3==1.4.1
coloredlogs==5.0
facebookads==2.8.1
factory-boy==2.6.1
flake8-blind-except==0.1.0
flake8-debugger==1.4.0
flake8-docstrings==0.2.5
flake8-isort==1.2
flake8-quotes==0.2.4
flake8==2.5.4
flask-shell-ipython==0.2.2
fluent-logger==0.4.4
google-api-python-client==1.6.1
googleads==4.6.1
ipdb==0.10.1
isort==4.2.2
itsdangerous==0.24
marshmallow==2.10.3
mockredispy==2.9.3
mysqlclient==1.3.9
openpyxl==2.4.1
pandas==0.18.1
pep8-naming==0.3.3
psycopg2==2.6.1
psycopg2==2.6.1
pytest==2.9.0
python-dateutil==2.4.2
raven==5.31.0
requests==2.10.0
unicodecsv==0.14.1
uwsgi==2.0.14
xlrd==1.0.0