diff --git a/MetaballsKit/MarchingSquares.swift b/MetaballsKit/MarchingSquares.swift index 528aea8..c6eba7c 100644 --- a/MetaballsKit/MarchingSquares.swift +++ b/MetaballsKit/MarchingSquares.swift @@ -17,12 +17,13 @@ class MarchingSquares { private var semaphore: DispatchSemaphore private var samplingPipeline: MTLComputePipelineState? + private var contouringPipeline: MTLComputePipelineState? private var parametersBuffer: MTLBuffer? /// Samples of the field's current state. private(set) var samplesBuffer: MTLBuffer? /// Indexes of geometry to render. - private(set) var indexBuffer: MTLBuffer? + private(set) var contourIndexesBuffer: MTLBuffer? private(set) var gridGeometry: MTLBuffer? @@ -40,28 +41,39 @@ class MarchingSquares { return xSamples * ySamples } + var contourIndexesCount: Int { + return (xSamples - 1) * (ySamples - 1) + } + init(field: Field) { self.field = field semaphore = DispatchSemaphore(value: 1) } func setupMetal(withDevice device: MTLDevice, library: MTLLibrary) { - guard let samplingFunction = library.makeFunction(name: "samplingKernel") else { - fatalError("Couldn't get samplingKernel function from library") - } - do { - samplingPipeline = try device.makeComputePipelineState(function: samplingFunction) - } catch let e { - fatalError("Error building compute pipeline state for sampling kernel: \(e)") - } - + samplingPipeline = createComputePipeline(withFunctionNamed: "samplingKernel", device: device, library: library) + contouringPipeline = createComputePipeline(withFunctionNamed: "contouringKernel", device: device, library: library) createParametersBuffer(withDevice: device) createSamplesBuffer(withDevice: device) + createContourIndexesBuffer(withDevice: device) + } + + func createComputePipeline(withFunctionNamed functionName: String, device: MTLDevice, library: MTLLibrary) -> MTLComputePipelineState? { + guard let function = library.makeFunction(name: functionName) else { + print("Couldn't get comput function \"\(functionName)\" from library") + return nil + } + do { + return try device.makeComputePipelineState(function: function) + } catch let e { + print("Error building compute pipeline state: \(e)") + return nil + } } func createParametersBuffer(withDevice device: MTLDevice) { // TODO: I'm cheating on this cause I didn't want to make a parallel struct in Swift and deal with alignment crap. >_> I should make a real struct for this. - let parametersLength = MemoryLayout.stride * 3 + MemoryLayout.stride + let parametersLength = MemoryLayout.stride * 3 + MemoryLayout.stride parametersBuffer = device.makeBuffer(length: parametersLength, options: .storageModeShared) } @@ -77,6 +89,18 @@ class MarchingSquares { } } + func createContourIndexesBuffer(withDevice device: MTLDevice) { + // Only reallocate the buffer if the length changed. + let length = MemoryLayout.stride * contourIndexesCount + guard contourIndexesBuffer?.length != length else { + return + } + contourIndexesBuffer = device.makeBuffer(length: length, options: .storageModePrivate) + if contourIndexesBuffer == nil { + fatalError("Couldn't create contourIndexesBuffer!") + } + } + func fieldDidResize() { // Please just get the device from somewhere. 😅 guard let device = gridGeometry?.device ?? samplesBuffer?.device else { @@ -142,6 +166,7 @@ class MarchingSquares { print("Couldn't create compute encoder") return } + encoder.label = "Sample Field" encoder.setComputePipelineState(samplingPipeline) encoder.setBuffer(parametersBuffer, offset: 0, index: 0) @@ -155,4 +180,28 @@ class MarchingSquares { encoder.endEncoding() } + + func encodeContouringKernel(intoBuffer buffer: MTLCommandBuffer) { + guard let pipeline = contouringPipeline else { + print("Encode called before contouring pipeline was set up!") + return + } + guard let encoder = buffer.makeComputeCommandEncoder() else { + print("Couldn't create compute encoder") + return + } + + encoder.label = "Contouring" + encoder.setComputePipelineState(pipeline) + encoder.setBuffer(parametersBuffer, offset: 0, index: 0) + encoder.setBuffer(samplesBuffer, offset: 0, index: 1) + encoder.setBuffer(contourIndexesBuffer, offset: 0, index: 2) + + // Dispatch! + let gridSize = MTLSize(width: contourIndexesCount, height: 1, depth: 1) + let threadgroupSize = MTLSize(width: xSamples - 1, height: 1, depth: 1) + encoder.dispatchThreads(gridSize, threadsPerThreadgroup: threadgroupSize) + + encoder.endEncoding() + } } diff --git a/MetaballsKit/Renderer.swift b/MetaballsKit/Renderer.swift index 5377b5b..9cd65ea 100644 --- a/MetaballsKit/Renderer.swift +++ b/MetaballsKit/Renderer.swift @@ -226,6 +226,7 @@ public class Renderer: NSObject, MTKViewDelegate { if let marchingSquares = delegate?.marchingSquares { // Compute samples first. marchingSquares.encodeSamplingKernel(intoBuffer: buffer) + marchingSquares.encodeContouringKernel(intoBuffer: buffer) // Render the marching squares version over top of the pixel version. // We need our own render pass descriptor that specifies that we load the results of the previous pass to make this render pass appear on top of the other. @@ -238,7 +239,7 @@ public class Renderer: NSObject, MTKViewDelegate { encoder.setVertexBytes(Rect.geometry, length: MemoryLayout.stride * Rect.geometry.count, index: 0) encoder.setVertexBuffer(marchingSquares.gridGeometry, offset: 0, index: 1) encoder.setVertexBuffer(parametersBuffer, offset: 0, index: 2) - encoder.setFragmentBuffer(marchingSquares.samplesBuffer, offset: 0, index: 0) + encoder.setFragmentBuffer(marchingSquares.contourIndexesBuffer, offset: 0, index: 0) encoder.drawPrimitives(type: .triangle, vertexStart: 0, vertexCount: Rect.geometry.count, instanceCount: marchingSquares.samplesCount) encoder.endEncoding() didEncode = true diff --git a/MetaballsKit/Shaders/MarchingSquares.metal b/MetaballsKit/Shaders/MarchingSquares.metal index 1a8b9a3..59af943 100644 --- a/MetaballsKit/Shaders/MarchingSquares.metal +++ b/MetaballsKit/Shaders/MarchingSquares.metal @@ -64,6 +64,28 @@ samplingKernel(constant MarchingSquaresParameters ¶meters [[buffer(0)]], samples[idx] = sample; } +kernel void +contouringKernel(constant MarchingSquaresParameters ¶meters [[buffer(0)]], + constant float *samples [[buffer(1)]], + device ushort *contourIndexes [[buffer(2)]], + uint position [[thread_position_in_grid]]) +{ + // Calculate an index based on the samples at the four points around this cell. + // If the point is above the threshold, adjust the value accordingly. + // d--c 8--4 + // | | -> | | + // a--b 1--2 + uint a = position + parameters.gridSize.x; + uint b = position + parameters.gridSize.x + 1; + uint c = position + 1; + uint d = position; + uint index = (samples[d] >= 1.0 ? 0b1000 : 0) + + (samples[c] >= 1.0 ? 0b0100 : 0) + + (samples[b] >= 1.0 ? 0b0010 : 0) + + (samples[a] >= 1.0 ? 0b0001 : 0); + contourIndexes[position] = index; +} + vertex RasterizerData gridVertexShader(constant Vertex *vertexes [[buffer(0)]], constant Rect *rects [[buffer(1)]], @@ -85,9 +107,9 @@ gridVertexShader(constant Vertex *vertexes [[buffer(0)]], fragment float4 gridFragmentShader(RasterizerData in [[stage_in]], - constant float *samples [[buffer(0)]]) + constant ushort *contourIndexes [[buffer(0)]]) { int instance = in.instance; - float sample = samples[instance]; - return sample > 1.0 ? in.color : float4(0); + uint sample = contourIndexes[instance]; + return sample >= 1 ? in.color : float4(0); }