Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions wespeaker/bin/export_ncnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) 2024, Chengdong Liang(liangchengdongd@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import os
import subprocess

import torch
import yaml
import ncnn
import numpy as np

from wespeaker.models.speaker_model import get_speaker_model
from wespeaker.utils.checkpoint import load_checkpoint


def get_args():
    """Build and parse the command-line arguments for ncnn model export.

    Returns:
        argparse.Namespace with ``config``, ``checkpoint`` and
        ``output_dir`` attributes; all three flags are required.
    """
    parser = argparse.ArgumentParser(description='export your script model')
    # All three options are mandatory; declare them table-driven.
    for flag, help_text in (('--config', 'config file'),
                            ('--checkpoint', 'checkpoint model'),
                            ('--output_dir', 'output dir')):
        parser.add_argument(flag, required=True, help=help_text)
    return parser.parse_args()


def test_ncnn_inference(in0, ncnn_param, ncnn_bin):
    """Run one forward pass through the exported ncnn model.

    Args:
        in0: input torch tensor; the leading batch dimension is squeezed
            away before being handed to ncnn and restored on the outputs.
        ncnn_param: path to the exported ``.ncnn.param`` file.
        ncnn_bin: path to the exported ``.ncnn.bin`` file.

    Returns:
        A single tensor when the net exposes one output ("out0"),
        otherwise a tuple of tensors ("out0", "out1", ...).
    """
    outputs = []

    with ncnn.Net() as net:
        net.load_param(ncnn_param)
        net.load_model(ncnn_bin)
        in_names = net.input_names()
        out_names = net.output_names()
        print("input_names: ", in_names)
        print("output_names: ", out_names)

        with net.create_extractor() as ex:
            # pnnx names the traced graph's I/O "in0"/"out0"/"out1".
            ex.input("in0", ncnn.Mat(in0.squeeze(0).numpy()).clone())
            _, mat0 = ex.extract("out0")
            outputs.append(torch.from_numpy(np.array(mat0)).unsqueeze(0))
            if len(out_names) > 1:
                _, mat1 = ex.extract("out1")
                outputs.append(torch.from_numpy(np.array(mat1)).unsqueeze(0))

    return outputs[0] if len(outputs) == 1 else tuple(outputs)


def main():
    """Export a wespeaker model to ncnn and verify it against torchscript.

    Pipeline: load config + checkpoint -> trace the model with torch.jit
    -> convert the traced model with the external ``pnnx`` tool -> feed
    the same random input through both backends and compare the outputs.
    """
    args = get_args()
    # No need gpu for model export
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    # ncnn_mode switches the model to ops that pnnx/ncnn can convert.
    configs['model_args']['ncnn_mode'] = True
    model = get_speaker_model(configs['model'])(**configs['model_args'])
    print(model)

    load_checkpoint(model, args.checkpoint)
    model.eval()
    # Export jit torch script model
    torch.manual_seed(0)
    x = torch.rand(1, 200, 80, dtype=torch.float).contiguous()
    script_model = torch.jit.trace(model, x)
    os.makedirs(args.output_dir, exist_ok=True)
    model_trace_path = os.path.join(args.output_dir, 'model.trace.pt')
    script_model.save(model_trace_path)
    print('Export trace model successfully, see {}'.format(model_trace_path))

    # Use an argv list instead of os.system's shell string so an
    # output_dir containing spaces or shell metacharacters can neither
    # break nor inject the command; check=True raises if pnnx fails
    # instead of silently ignoring its exit status.
    subprocess.run(['pnnx', model_trace_path, 'inputshape=[1,200,80]f32'],
                   check=True)
    prefix = model_trace_path[:-3]  # strip the '.pt' suffix
    ncnn_param = prefix + '.ncnn.param'
    ncnn_bin = prefix + '.ncnn.bin'
    print('The ncnn model is saved in {} and {}'.format(ncnn_param, ncnn_bin))

    torch_output = model(x)
    ncnn_output = test_ncnn_inference(x, ncnn_param, ncnn_bin)
    if isinstance(torch_output, tuple):
        # Multi-output models: compare only the second (embedding) output,
        # mirroring the "out1" extraction in test_ncnn_inference.
        torch_output = torch_output[1]
        ncnn_output = ncnn_output[1]

    # Loose atol: pnnx conversion + ncnn fp32 kernels introduce small
    # numeric drift relative to torchscript.
    if np.allclose(torch_output.detach().numpy(),
                   ncnn_output.detach().numpy(),
                   rtol=1e-5,
                   atol=1e-2):
        print("Export ncnn model successfully, "
              "and the output accuracy check passed!")
    else:
        print("Export ncnn model successfully, but ncnn and torchscript have "
              "different outputs when given the same input, please check!")


if __name__ == '__main__':
    main()
50 changes: 39 additions & 11 deletions wespeaker/models/campplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def __init__(self,
padding,
dilation,
bias,
reduction=2):
reduction=2,
ncnn_mode=False):
super(CAMLayer, self).__init__()
self.linear_local = nn.Conv1d(bn_channels,
out_channels,
Expand All @@ -105,6 +106,7 @@ def __init__(self,
self.relu = nn.ReLU(inplace=True)
self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
self.sigmoid = nn.Sigmoid()
self.ncnn_mode = ncnn_mode

def forward(self, x):
y = self.linear_local(x)
Expand All @@ -127,9 +129,14 @@ def seg_pooling(self, x, seg_len: int = 100, stype: str = 'avg'):
else:
raise ValueError('Wrong segment pooling type.')
shape = seg.shape
seg = seg.unsqueeze(-1).expand(shape[0], shape[1], shape[2],
seg_len).reshape(
shape[0], shape[1], -1)
if not self.ncnn_mode:
seg = seg.unsqueeze(-1).expand(shape[0], shape[1], shape[2],
seg_len).reshape(
shape[0], shape[1], -1)
else:
seg = (seg.unsqueeze(-1) +
torch.zeros(shape[0], shape[1], shape[2], seg_len)).reshape(
shape[0], shape[1], -1).to(seg.device)
seg = seg[..., :x.shape[-1]]
return seg

Expand All @@ -144,7 +151,8 @@ def __init__(self,
stride=1,
dilation=1,
bias=False,
config_str='batchnorm-relu'):
config_str='batchnorm-relu',
ncnn_mode=False):
super(CAMDenseTDNNLayer, self).__init__()
assert kernel_size % 2 == 1, 'Expect equal paddings, \
but got even kernel size ({})'.format(kernel_size)
Expand All @@ -158,7 +166,8 @@ def __init__(self,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
bias=bias,
ncnn_mode=ncnn_mode)

def bn_function(self, x):
return self.linear1(self.nonlinear1(x))
Expand All @@ -180,7 +189,8 @@ def __init__(self,
stride=1,
dilation=1,
bias=False,
config_str='batchnorm-relu'):
config_str='batchnorm-relu',
ncnn_mode=False):
super(CAMDenseTDNNBlock, self).__init__()
for i in range(num_layers):
layer = CAMDenseTDNNLayer(in_channels=in_channels +
Expand All @@ -191,7 +201,8 @@ def __init__(self,
stride=stride,
dilation=dilation,
bias=bias,
config_str=config_str)
config_str=config_str,
ncnn_mode=ncnn_mode)
self.add_module('tdnnd%d' % (i + 1), layer)

def forward(self, x):
Expand Down Expand Up @@ -338,7 +349,8 @@ def __init__(self,
growth_rate=32,
bn_size=4,
init_channels=128,
config_str='batchnorm-relu'):
config_str='batchnorm-relu',
ncnn_mode=False):
super(CAMPPlus, self).__init__()

self.head = FCM(block=BasicResBlock,
Expand Down Expand Up @@ -367,7 +379,8 @@ def __init__(self,
bn_channels=bn_size * growth_rate,
kernel_size=kernel_size,
dilation=dilation,
config_str=config_str)
config_str=config_str,
ncnn_mode=ncnn_mode)
self.xvector.add_module('block%d' % (i + 1), block)
channels = channels + num_layers * growth_rate
self.xvector.add_module(
Expand All @@ -381,7 +394,8 @@ def __init__(self,
self.xvector.add_module('out_nonlinear',
get_nonlinear(config_str, channels))

self.pool = getattr(pooling_layers, pooling_func)(in_dim=channels)
self.pool = getattr(pooling_layers, pooling_func)(in_dim=channels,
ncnn_mode=ncnn_mode)
self.pool_out_dim = self.pool.get_out_dim()
self.xvector.add_module('stats', self.pool)
self.xvector.add_module(
Expand Down Expand Up @@ -411,6 +425,20 @@ def forward(self, x):
num_params = sum(param.numel() for param in model.parameters())
print("{} M".format(num_params / 1e6))

state_dict = model.state_dict()
model_ncnn = CAMPPlus(feat_dim=80,
embed_dim=512,
pooling_func='TSTP',
ncnn_mode=True)
model_ncnn.eval()
model_ncnn.load_state_dict(state_dict)
out_ncnn = model_ncnn(x)

torch.testing.assert_allclose(out.detach().numpy(),
out_ncnn.detach().numpy(),
rtol=1e-5,
atol=1e-3)

# from thop import profile
# x_np = torch.randn(1, 200, 80)
# flops, params = profile(model, inputs=(x_np, ))
Expand Down
45 changes: 35 additions & 10 deletions wespeaker/models/ecapa_tdnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def __init__(self,
embed_dim=192,
pooling_func='ASTP',
global_context_att=False,
emb_bn=False):
emb_bn=False,
ncnn_mode=False):
super().__init__()

self.layer1 = Conv1dReluBn(feat_dim,
Expand Down Expand Up @@ -194,7 +195,9 @@ def __init__(self,
out_channels = 512 * 3
self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1)
self.pool = getattr(pooling_layers, pooling_func)(
in_dim=out_channels, global_context_att=global_context_att)
in_dim=out_channels,
global_context_att=global_context_att,
ncnn_mode=ncnn_mode)
self.pool_out_dim = self.pool.get_out_dim()
self.bn = nn.BatchNorm1d(self.pool_out_dim)
self.linear = nn.Linear(self.pool_out_dim, embed_dim)
Expand All @@ -221,44 +224,58 @@ def forward(self, x):
return out


def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func='ASTP', emb_bn=False):
def ECAPA_TDNN_c1024(feat_dim,
embed_dim,
pooling_func='ASTP',
emb_bn=False,
ncnn_mode=False):
return ECAPA_TDNN(channels=1024,
feat_dim=feat_dim,
embed_dim=embed_dim,
pooling_func=pooling_func,
emb_bn=emb_bn)
emb_bn=emb_bn,
ncnn_mode=ncnn_mode)


def ECAPA_TDNN_GLOB_c1024(feat_dim,
embed_dim,
pooling_func='ASTP',
emb_bn=False):
emb_bn=False,
ncnn_mode=False):
return ECAPA_TDNN(channels=1024,
feat_dim=feat_dim,
embed_dim=embed_dim,
pooling_func=pooling_func,
global_context_att=True,
emb_bn=emb_bn)
emb_bn=emb_bn,
ncnn_mode=ncnn_mode)


def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func='ASTP', emb_bn=False):
def ECAPA_TDNN_c512(feat_dim,
embed_dim,
pooling_func='ASTP',
emb_bn=False,
ncnn_mode=False):
return ECAPA_TDNN(channels=512,
feat_dim=feat_dim,
embed_dim=embed_dim,
pooling_func=pooling_func,
emb_bn=emb_bn)
emb_bn=emb_bn,
ncnn_mode=ncnn_mode)


def ECAPA_TDNN_GLOB_c512(feat_dim,
embed_dim,
pooling_func='ASTP',
emb_bn=False):
emb_bn=False,
ncnn_mode=False):
return ECAPA_TDNN(channels=512,
feat_dim=feat_dim,
embed_dim=embed_dim,
pooling_func=pooling_func,
global_context_att=True,
emb_bn=emb_bn)
emb_bn=emb_bn,
ncnn_mode=ncnn_mode)


if __name__ == '__main__':
Expand All @@ -273,6 +290,14 @@ def ECAPA_TDNN_GLOB_c512(feat_dim,
num_params = sum(param.numel() for param in model.parameters())
print("{} M".format(num_params / 1e6))

model_ncnn_mode = ECAPA_TDNN_GLOB_c512(feat_dim=80,
embed_dim=256,
pooling_func='ASTP',
ncnn_mode=True)
model_ncnn_mode.eval()
out_ncnn_mode = model_ncnn_mode(x)
torch.testing.assert_allclose(out, out_ncnn_mode, rtol=1e-5, atol=1e-3)

# from thop import profile
# x_np = torch.randn(1, 200, 80)
# flops, params = profile(model, inputs=(x_np, ))
Expand Down
Loading