Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ set(spirv-cross-cpp-sources

# Files collected in the spirv-cross-msl-sources variable: two .cpp sources
# (spirv_msl.cpp and spirv_msl_vertex_loader.cpp) and the spirv_msl.hpp header.
set(spirv-cross-msl-sources
${CMAKE_CURRENT_SOURCE_DIR}/spirv_msl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/spirv_msl_vertex_loader.cpp
${CMAKE_CURRENT_SOURCE_DIR}/spirv_msl.hpp)

set(spirv-cross-hlsl-sources
Expand Down
203 changes: 203 additions & 0 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,9 @@ struct CLIArguments
bool msl_ios_use_simdgroup_functions = false;
bool msl_emulate_subgroups = false;
uint32_t msl_fixed_subgroup_size = 0;
CompilerMSL::Options::IndexType msl_vertex_index_type = CompilerMSL::Options::IndexType::None;
bool msl_use_pixel_type_loads = false;
bool msl_dynamic_vertex_stride = false;
bool msl_force_sample_rate_shading = false;
bool msl_manual_helper_invocation_updates = true;
bool msl_check_discarded_frag_stores = false;
Expand All @@ -697,6 +700,8 @@ struct CLIArguments
SmallVector<pair<uint32_t, uint32_t>> msl_inline_uniform_blocks;
SmallVector<MSLShaderInterfaceVariable> msl_shader_inputs;
SmallVector<MSLShaderInterfaceVariable> msl_shader_outputs;
SmallVector<MSLVertexBinding> msl_vertex_bindings;
SmallVector<MSLVertexAttribute> msl_vertex_attributes;
SmallVector<PLSArg> pls_in;
SmallVector<PLSArg> pls_out;
SmallVector<Remap> remaps;
Expand Down Expand Up @@ -919,6 +924,8 @@ static void print_help_msl()
"\t\t<format> can be 'any32', 'any16', 'u16', 'u8', or 'other', to indicate a 32-bit opaque value, 16-bit opaque value, 16-bit unsigned integer, 8-bit unsigned integer, "
"or other-typed variable. <size> is the vector length of the variable, which must be greater than or equal to that declared in the shader."
"\t\tEquivalent to --msl-add-shader-output with a rate of 'vertex'.\n"
"\t[--msl-vertex-binding <buffer> <stride> (vertex|instance) <divisor>]:\n\t\tAdd a vertex buffer for the shader vertex loader\n"
"\t[--msl-vertex-attribute <location> <binding> <VkFormat> <offset>]:\n\t\tAdd a vertex attribute for the shader vertex loader\n"
"\t[--msl-raw-buffer-tese-input]:\n\t\tUse raw buffers for tessellation evaluation input.\n"
"\t\tThis allows the use of nested structures and arrays.\n"
"\t\tIn a future version of SPIRV-Cross, this will become the default.\n"
Expand All @@ -945,6 +952,8 @@ static void print_help_msl()
"\t[--msl-fixed-subgroup-size <size>]:\n\t\tAssign a constant <size> to the SubgroupSize builtin.\n"
"\t\tIntended for Vulkan Portability implementations where VK_EXT_subgroup_size_control is not supported or disabled.\n"
"\t\tIf 0, assume variable subgroup size as actually exposed by Metal.\n"
"\t[--msl-use-pixel-type-loads]:\n\t\tEnable use of MSL pixel-type loads (e.g. rgb9e5<float3>).\n"
"\t[--msl-dynamic-vertex-stride]:\n\t\tEnable dynamic strides in shader vertex loader.\n"
"\t[--msl-force-sample-rate-shading]:\n\t\tForce fragment shaders to run per sample.\n"
"\t\tThis adds a [[sample_id]] parameter if none is already present.\n"
"\t[--msl-no-manual-helper-invocation-updates]:\n\t\tDo not manually update the HelperInvocation builtin when a fragment is discarded.\n"
Expand Down Expand Up @@ -1229,6 +1238,9 @@ static string compile_iteration(const CLIArguments &args, std::vector<uint32_t>
msl_opts.ios_use_simdgroup_functions = args.msl_ios_use_simdgroup_functions;
msl_opts.emulate_subgroups = args.msl_emulate_subgroups;
msl_opts.fixed_subgroup_size = args.msl_fixed_subgroup_size;
msl_opts.vertex_index_type = args.msl_vertex_index_type;
msl_opts.vertex_loader_dynamic_stride = args.msl_dynamic_vertex_stride;
msl_opts.use_pixel_type_loads = args.msl_use_pixel_type_loads;
msl_opts.force_sample_rate_shading = args.msl_force_sample_rate_shading;
msl_opts.manual_helper_invocation_updates = args.msl_manual_helper_invocation_updates;
msl_opts.check_discarded_frag_stores = args.msl_check_discarded_frag_stores;
Expand All @@ -1237,6 +1249,10 @@ static string compile_iteration(const CLIArguments &args, std::vector<uint32_t>
msl_opts.runtime_array_rich_descriptor = args.msl_runtime_array_rich_descriptor;
msl_opts.replace_recursive_inputs = args.msl_replace_recursive_inputs;
msl_comp->set_msl_options(msl_opts);
for (auto &v : args.msl_vertex_bindings)
msl_comp->add_shader_vertex_loader_binding(v);
for (auto &v : args.msl_vertex_attributes)
msl_comp->add_shader_vertex_loader_attribute(v);
for (auto &v : args.msl_discrete_descriptor_sets)
msl_comp->add_discrete_descriptor_set(v);
for (auto &v : args.msl_device_argument_buffers)
Expand Down Expand Up @@ -1769,6 +1785,179 @@ static int main_inner(int argc, char *argv[])
output.vecsize = parser.next_uint();
args.msl_shader_outputs.push_back(output);
});
cbs.add("--msl-vertex-binding", [&args](CLIParser &parser) {
	// Arguments are consumed in the order: <buffer> <stride> (vertex|instance) <divisor>.
	MSLVertexBinding vb;
	vb.binding = parser.next_uint();
	vb.stride = parser.next_uint();

	// Lower-case the rate keyword so it is matched case-insensitively.
	// Each char goes through uint8_t before tolower(), as in the lambda-based form.
	std::string rate_name = parser.next_string();
	for (char &ch : rate_name)
		ch = static_cast<char>(tolower(static_cast<uint8_t>(ch)));

	if (rate_name == "vertex")
	{
		vb.rate = MSL_VERTEX_INPUT_RATE_VERTEX;
	}
	else if (rate_name == "instance")
	{
		vb.rate = MSL_VERTEX_INPUT_RATE_INSTANCE;
	}
	else
	{
		// An unknown keyword is rejected before the divisor is read.
		THROW("Bad vertex binding input rate");
	}

	vb.divisor = parser.next_uint();
	args.msl_vertex_bindings.push_back(vb);
});
// Registers the --msl-vertex-attribute option.
// The callback consumes four arguments in order: <location> <binding> <format name> <offset>,
// resolves the format name to an MSLFormat through the table below, and appends the
// resulting MSLVertexAttribute to args.msl_vertex_attributes.
// An unrecognised format name is reported through THROW("Bad vertex attribute format").
cbs.add("--msl-vertex-attribute", [&args](CLIParser &parser) {
MSLVertexAttribute attribute;
attribute.location = parser.next_uint();
attribute.binding = parser.next_uint();
std::string format = parser.next_string();
attribute.offset = parser.next_uint();
// Upper-case the name so the table lookup is case-insensitive.
std::transform(format.begin(), format.end(), format.begin(), [](uint8_t c){ return toupper(c); });
// An optional "MSL_FORMAT_" or "VK_FORMAT_" prefix is skipped; format_beg is the
// index of the first character that takes part in the comparison.
size_t format_beg = 0;
if (0 == format.compare(0, strlen("MSL_FORMAT_"), "MSL_FORMAT_"))
format_beg = strlen("MSL_FORMAT_");
if (0 == format.compare(0, strlen("VK_FORMAT_"), "VK_FORMAT_"))
format_beg = strlen("VK_FORMAT_");
// Lookup table pairing each MSL_FORMAT_<x> enumerator with the string "<x>".
static constexpr struct {
MSLFormat format;
const char* name;
} formats[] = {
// FORMAT(x) expands to { MSL_FORMAT_x, "x" }; the macro is undefined again after the table.
#define FORMAT(x) { MSL_FORMAT_##x, #x }
FORMAT(R4G4_UNORM_PACK8),
FORMAT(R4G4B4A4_UNORM_PACK16),
FORMAT(B4G4R4A4_UNORM_PACK16),
FORMAT(R5G6B5_UNORM_PACK16),
FORMAT(B5G6R5_UNORM_PACK16),
FORMAT(R5G5B5A1_UNORM_PACK16),
FORMAT(B5G5R5A1_UNORM_PACK16),
FORMAT(A1R5G5B5_UNORM_PACK16),
FORMAT(R8_UNORM),
FORMAT(R8_SNORM),
FORMAT(R8_USCALED),
FORMAT(R8_SSCALED),
FORMAT(R8_UINT),
FORMAT(R8_SINT),
FORMAT(R8_SRGB),
FORMAT(R8G8_UNORM),
FORMAT(R8G8_SNORM),
FORMAT(R8G8_USCALED),
FORMAT(R8G8_SSCALED),
FORMAT(R8G8_UINT),
FORMAT(R8G8_SINT),
FORMAT(R8G8_SRGB),
FORMAT(R8G8B8_UNORM),
FORMAT(R8G8B8_SNORM),
FORMAT(R8G8B8_USCALED),
FORMAT(R8G8B8_SSCALED),
FORMAT(R8G8B8_UINT),
FORMAT(R8G8B8_SINT),
FORMAT(R8G8B8_SRGB),
FORMAT(B8G8R8_UNORM),
FORMAT(B8G8R8_SNORM),
FORMAT(B8G8R8_USCALED),
FORMAT(B8G8R8_SSCALED),
FORMAT(B8G8R8_UINT),
FORMAT(B8G8R8_SINT),
FORMAT(B8G8R8_SRGB),
FORMAT(R8G8B8A8_UNORM),
FORMAT(R8G8B8A8_SNORM),
FORMAT(R8G8B8A8_USCALED),
FORMAT(R8G8B8A8_SSCALED),
FORMAT(R8G8B8A8_UINT),
FORMAT(R8G8B8A8_SINT),
FORMAT(R8G8B8A8_SRGB),
FORMAT(B8G8R8A8_UNORM),
FORMAT(B8G8R8A8_SNORM),
FORMAT(B8G8R8A8_USCALED),
FORMAT(B8G8R8A8_SSCALED),
FORMAT(B8G8R8A8_UINT),
FORMAT(B8G8R8A8_SINT),
FORMAT(B8G8R8A8_SRGB),
FORMAT(A8B8G8R8_UNORM_PACK32),
FORMAT(A8B8G8R8_SNORM_PACK32),
FORMAT(A8B8G8R8_USCALED_PACK32),
FORMAT(A8B8G8R8_SSCALED_PACK32),
FORMAT(A8B8G8R8_UINT_PACK32),
FORMAT(A8B8G8R8_SINT_PACK32),
FORMAT(A8B8G8R8_SRGB_PACK32),
FORMAT(A2R10G10B10_UNORM_PACK32),
FORMAT(A2R10G10B10_SNORM_PACK32),
FORMAT(A2R10G10B10_USCALED_PACK32),
FORMAT(A2R10G10B10_SSCALED_PACK32),
FORMAT(A2R10G10B10_UINT_PACK32),
FORMAT(A2R10G10B10_SINT_PACK32),
FORMAT(A2B10G10R10_UNORM_PACK32),
FORMAT(A2B10G10R10_SNORM_PACK32),
FORMAT(A2B10G10R10_USCALED_PACK32),
FORMAT(A2B10G10R10_SSCALED_PACK32),
FORMAT(A2B10G10R10_UINT_PACK32),
FORMAT(A2B10G10R10_SINT_PACK32),
FORMAT(R16_UNORM),
FORMAT(R16_SNORM),
FORMAT(R16_USCALED),
FORMAT(R16_SSCALED),
FORMAT(R16_UINT),
FORMAT(R16_SINT),
FORMAT(R16_SFLOAT),
FORMAT(R16G16_UNORM),
FORMAT(R16G16_SNORM),
FORMAT(R16G16_USCALED),
FORMAT(R16G16_SSCALED),
FORMAT(R16G16_UINT),
FORMAT(R16G16_SINT),
FORMAT(R16G16_SFLOAT),
FORMAT(R16G16B16_UNORM),
FORMAT(R16G16B16_SNORM),
FORMAT(R16G16B16_USCALED),
FORMAT(R16G16B16_SSCALED),
FORMAT(R16G16B16_UINT),
FORMAT(R16G16B16_SINT),
FORMAT(R16G16B16_SFLOAT),
FORMAT(R16G16B16A16_UNORM),
FORMAT(R16G16B16A16_SNORM),
FORMAT(R16G16B16A16_USCALED),
FORMAT(R16G16B16A16_SSCALED),
FORMAT(R16G16B16A16_UINT),
FORMAT(R16G16B16A16_SINT),
FORMAT(R16G16B16A16_SFLOAT),
FORMAT(R32_UINT),
FORMAT(R32_SINT),
FORMAT(R32_SFLOAT),
FORMAT(R32G32_UINT),
FORMAT(R32G32_SINT),
FORMAT(R32G32_SFLOAT),
FORMAT(R32G32B32_UINT),
FORMAT(R32G32B32_SINT),
FORMAT(R32G32B32_SFLOAT),
FORMAT(R32G32B32A32_UINT),
FORMAT(R32G32B32A32_SINT),
FORMAT(R32G32B32A32_SFLOAT),
FORMAT(R64_UINT),
FORMAT(R64_SINT),
FORMAT(R64_SFLOAT),
FORMAT(R64G64_UINT),
FORMAT(R64G64_SINT),
FORMAT(R64G64_SFLOAT),
FORMAT(R64G64B64_UINT),
FORMAT(R64G64B64_SINT),
FORMAT(R64G64B64_SFLOAT),
FORMAT(R64G64B64A64_UINT),
FORMAT(R64G64B64A64_SINT),
FORMAT(R64G64B64A64_SFLOAT),
FORMAT(B10G11R11_UFLOAT_PACK32),
FORMAT(E5B9G9R9_UFLOAT_PACK32),
FORMAT(G16B16G16R16_422_UNORM),
FORMAT(B16G16R16G16_422_UNORM),
FORMAT(A4R4G4B4_UNORM_PACK16),
FORMAT(A4B4G4R4_UNORM_PACK16),
FORMAT(R16G16_S10_5_NV),
FORMAT(A1B5G5R5_UNORM_PACK16_KHR),
FORMAT(A8_UNORM_KHR),
#undef FORMAT
};
// Linear search: compare everything from format_beg to the end of the string
// against each table name; fmt_idx stops at the first exact match.
size_t fmt_idx = 0;
for (; fmt_idx < sizeof(formats) / sizeof(*formats); fmt_idx++)
if (0 == format.compare(format_beg, format.npos, formats[fmt_idx].name))
break;
// Reaching the element count means no entry matched.
if (fmt_idx == sizeof(formats) / sizeof(*formats))
THROW("Bad vertex attribute format");
attribute.format = formats[fmt_idx].format;
args.msl_vertex_attributes.push_back(attribute);
});
cbs.add("--msl-raw-buffer-tese-input", [&args](CLIParser &) { args.msl_raw_buffer_tese_input = true; });
cbs.add("--msl-multi-patch-workgroup", [&args](CLIParser &) { args.msl_multi_patch_workgroup = true; });
cbs.add("--msl-vertex-for-tessellation", [&args](CLIParser &) { args.msl_vertex_for_tessellation = true; });
Expand All @@ -1784,6 +1973,20 @@ static int main_inner(int argc, char *argv[])
cbs.add("--msl-emulate-subgroups", [&args](CLIParser &) { args.msl_emulate_subgroups = true; });
cbs.add("--msl-fixed-subgroup-size",
[&args](CLIParser &parser) { args.msl_fixed_subgroup_size = parser.next_uint(); });
cbs.add("--msl-vertex-index-type", [&args](CLIParser &parser) {
	using IndexType = CompilerMSL::Options::IndexType;

	// The keyword is lower-cased first, so "UInt16", "UINT16" and "uint16" are equivalent.
	std::string keyword = parser.next_string();
	for (char &ch : keyword)
		ch = static_cast<char>(tolower(static_cast<uint8_t>(ch)));

	if (keyword == "none")
	{
		args.msl_vertex_index_type = IndexType::None;
		return;
	}
	if (keyword == "uint16")
	{
		args.msl_vertex_index_type = IndexType::UInt16;
		return;
	}
	if (keyword == "uint32")
	{
		args.msl_vertex_index_type = IndexType::UInt32;
		return;
	}

	// Anything other than the three keywords above is rejected.
	THROW("Bad index type");
});
cbs.add("--msl-use-pixel-type-loads", [&args](CLIParser &) { args.msl_use_pixel_type_loads = true; });
cbs.add("--msl-dynamic-vertex-stride", [&args](CLIParser &) { args.msl_dynamic_vertex_stride = true; });
cbs.add("--msl-force-sample-rate-shading", [&args](CLIParser &) { args.msl_force_sample_rate_shading = true; });
cbs.add("--msl-no-manual-helper-invocation-updates",
[&args](CLIParser &) { args.msl_manual_helper_invocation_updates = false; });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ fragment main0_out main0(main0_in in [[stage_in]], uint gl_SampleID [[sample_id]
_13 s = {};
spvUnsafeArray<float2, 2> b = {};
spvUnsafeArray<float2, 2> c = {};
a[0] = in.a_0.interpolate_at_center();
a[1] = in.a_1.interpolate_at_center();
s.x = in.s_x.interpolate_at_center();
s.y = in.s_y.interpolate_at_centroid();
s.z = in.s_z.interpolate_at_sample(gl_SampleID);
Expand All @@ -102,6 +100,8 @@ fragment main0_out main0(main0_in in [[stage_in]], uint gl_SampleID [[sample_id]
s.w[0] = in.s_w_0.interpolate_at_center();
s.w[1] = in.s_w_1.interpolate_at_center();
s.w[2] = in.s_w_2.interpolate_at_center();
a[0] = in.a_0.interpolate_at_center();
a[1] = in.a_1.interpolate_at_center();
b[0] = in.b_0.interpolate_at_centroid();
b[1] = in.b_1.interpolate_at_centroid();
c[0] = in.c_0.interpolate_at_sample(gl_SampleID);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ kernel void main0(main0_in in [[stage_in]], uint3 gl_GlobalInvocationID [[thread
spvUnsafeArray<float, 3> InC = {};
float InD = {};
device main0_out& out = spvOut[gl_GlobalInvocationID.y * spvStageInputSize.x + gl_GlobalInvocationID.x];
if (any(gl_GlobalInvocationID >= spvStageInputSize))
return;
InA[0] = in.m_location_1.x;
InA[1] = in.m_location_2.x;
InB[0] = in.m_location_1.zw;
Expand All @@ -79,8 +81,6 @@ kernel void main0(main0_in in [[stage_in]], uint3 gl_GlobalInvocationID [[thread
InC[1] = in.m_location_1.y;
InC[2] = in.m_location_2.y;
InD = in.m_location_0.w;
if (any(gl_GlobalInvocationID >= spvStageInputSize))
return;
out.gl_Position = in.Pos;
A = InA;
B = InB;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

static half3 spvLoadVertexRG11B10Half(uint value)
{
    // Move each packed field so that its top bit lands on bit 14 of a 16-bit lane
    // (bit 15 stays clear), then mask off everything that belongs to other fields.
    const uint r_bits = (value << 4) & 0x7ff0;  // source bits 0..10  -> lane bits 4..14
    const uint g_bits = (value >> 7) & 0x7ff0;  // source bits 11..21 -> lane bits 4..14
    const uint b_bits = (value >> 17) & 0x7fe0; // source bits 22..31 -> lane bits 5..14

    // Reinterpret the three 16-bit patterns as half values without numeric conversion.
    ushort3 lanes = ushort3(r_bits, g_bits, b_bits);
    return as_type<half3>(lanes);
}
static float3 spvLoadVertexRGB9E5Float(uint value)
{
    // The top 5 bits hold the exponent shared by all three channels.
    const float shared_pow = exp2(float(value >> 27));
    // Constant factor 2^-(15 + 9) applied on top of the shared exponent.
    const float bias_pow = exp2(float(-(15 + 9)));
    const float scale = shared_pow * bias_pow;

    // Three 9-bit fields starting at bits 0, 9 and 18.
    const uint low_field = value & 0x1ff;
    const uint mid_field = extract_bits(value, 9, 9);
    const uint high_field = extract_bits(value, 18, 9);
    uint3 fields = uint3(low_field, mid_field, high_field);

    return float3(fields) * scale;
}
// Return type of the main0 vertex function; its only member carries the [[position]] attribute.
struct main0_out
{
float4 gl_Position [[position]];
};

// Vertex inputs consumed by main0. In this file the struct is populated by
// spvLoadVertex rather than received as a [[stage_in]] parameter.
// No member is declared for attribute index 2.
struct main0_in
{
float4 a0 [[attribute(0)]];
float4 a1 [[attribute(1)]];
float4 a3 [[attribute(3)]];
float4 a4 [[attribute(4)]];
float4 a5 [[attribute(5)]];
float4 a6 [[attribute(6)]];
// The only non-float4 input; main0 widens it with float4(float(in.a7)).
uint a7 [[attribute(7)]];
float4 a8 [[attribute(8)]];
};

// Raw layout that main0 overlays on an element of spvVertexBuffer0.
// spvLoadVertex reads a0 and a1 from it.
struct spvVertexData0
{
// Four bytes, unpacked by spvLoadVertex with unpack_unorm4x8_to_float.
uchar4 a0;
// Padding member between a0 and a1; never read by spvLoadVertex.
uchar spvPad4;
// Three bytes, swizzled with .bgr when loaded.
packed_uchar3 a1;
};
// Compile-time check that the struct has the alignment the layout above assumes.
static_assert(alignof(spvVertexData0) == 4, "Unexpected alignment");

// Raw layout that main0 overlays on an element of spvVertexBuffer1.
// spvLoadVertex reads only a3 from it.
struct spvVertexData1
{
// Padding of four ushort elements placed before a3; never read by spvLoadVertex.
ushort spvPad0[4];
// Four signed 16-bit values; spvLoadVertex scales them by 1/32767 and clamps at -1.
packed_short4 a3;
};
// Compile-time check that the struct has the alignment the layout above assumes.
static_assert(alignof(spvVertexData1) == 2, "Unexpected alignment");

// Raw layout that main0 overlays on an element of spvVertexBuffer2.
// Each member is one packed 32-bit word decoded by spvLoadVertex.
struct spvVertexData2
{
// Decoded with unpack_unorm10a2_to_float.
uint a4;
// Decoded with spvLoadVertexRGB9E5Float.
uint a5;
// Decoded with spvLoadVertexRG11B10Half.
uint a6;
};
// Compile-time check that the struct has the alignment the layout above assumes.
static_assert(alignof(spvVertexData2) == 4, "Unexpected alignment");

// Raw layout that main0 overlays on an element of spvVertexBuffer3.
// spvLoadVertex reads a8 and a7 from it; the spvPad* members are never read.
struct spvVertexData3
{
// Single padding byte ahead of a8.
uchar spvPad0;
// One byte, scaled by 1/255 when loaded.
uchar a8;
// Fifteen ushort elements of padding separating a8 from a7.
ushort spvPad2[15];
// Copied unchanged into main0_in::a7.
uint a7;
};
// Compile-time check that the struct has the alignment the layout above assumes.
static_assert(alignof(spvVertexData3) == 4, "Unexpected alignment");

// Builds the shader's main0_in from the four raw per-buffer records,
// converting each packed field to the type main0 expects.
main0_in spvLoadVertex(const device spvVertexData0& data0, const device spvVertexData1& data1, const device spvVertexData2& data2, const device spvVertexData3& data3)
{
    main0_in loaded;

    // Buffer 0: four normalized bytes, then three bytes reordered with .bgr and w = 1.
    loaded.a0 = unpack_unorm4x8_to_float(as_type<uint>(data0.a0));
    float3 a1_values = float3(uchar3(data0.a1));
    loaded.a1 = float4(a1_values.bgr, 1);

    // Buffer 1: signed 16-bit values scaled by 1/32767, with a lower bound of -1.
    float4 a3_scaled = float4(short4(data1.a3)) * (1.f / 32767);
    loaded.a3 = max(a3_scaled, -1.f);

    // Buffer 2: three packed 32-bit words, each with its own decoder.
    loaded.a4 = unpack_unorm10a2_to_float(data2.a4);
    float3 a5_rgb = spvLoadVertexRGB9E5Float(data2.a5);
    loaded.a5 = float4(a5_rgb, 1);
    float3 a6_rgb = float3(spvLoadVertexRG11B10Half(data2.a6));
    loaded.a6 = float4(a6_rgb, 1);

    // Buffer 3: a raw uint passed through, and one byte scaled by 1/255 into x.
    loaded.a7 = data3.a7;
    float a8_x = float(data3.a8) * (1.f / 255);
    loaded.a8 = float4(a8_x, 0, 0, 1);

    return loaded;
}

// Vertex entry point: locates this invocation's record in each of the four vertex
// buffers using the strides in spvVertexStrides, decodes them with spvLoadVertex,
// and writes the sum of all inputs to gl_Position.
vertex main0_out main0(device const uchar* spvVertexBuffer0 [[buffer(0)]], device const uchar* spvVertexBuffer1 [[buffer(1)]], device const uchar* spvVertexBuffer2 [[buffer(2)]], device const uchar* spvVertexBuffer3 [[buffer(3)]], uint gl_VertexIndex [[vertex_id]], uint gl_BaseVertex [[base_vertex]], uint gl_InstanceIndex [[instance_id]], uint gl_BaseInstance [[base_instance]], const device uint* spvVertexStrides [[buffer(19)]])
{
    // Byte address of the record each buffer contributes to this invocation.
    device const uchar* record0 = spvVertexBuffer0 + spvVertexStrides[0] * gl_InstanceIndex;
    device const uchar* record1 = spvVertexBuffer1 + spvVertexStrides[1] * gl_VertexIndex;
    device const uchar* record2 = spvVertexBuffer2 + spvVertexStrides[2] * gl_BaseInstance;
    // Buffer 3 moves to the next record once every 4 instances past the base instance.
    device const uchar* record3 = spvVertexBuffer3 + spvVertexStrides[3] * (gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4);

    main0_in in = spvLoadVertex(*reinterpret_cast<device const spvVertexData0*>(record0),
                                *reinterpret_cast<device const spvVertexData1*>(record1),
                                *reinterpret_cast<device const spvVertexData2*>(record2),
                                *reinterpret_cast<device const spvVertexData3*>(record3));

    // Accumulate left to right, in the same order as the single-expression form.
    float4 total = in.a0 + in.a1;
    total = total + in.a3;
    total = total + in.a4;
    total = total + in.a5;
    total = total + in.a6;
    total = total + float4(float(in.a7));
    total = total + in.a8;

    main0_out out = {};
    out.gl_Position = total;
    return out;
}

Loading