diff --git a/MetaballsKit/Metaballs.swift b/MetaballsKit/Metaballs.swift index d9ae4a1..247844a 100644 --- a/MetaballsKit/Metaballs.swift +++ b/MetaballsKit/Metaballs.swift @@ -38,6 +38,7 @@ public class Field { didSet { // Remove balls that fall outside the new bounds. balls = balls.filter { bounds.contains($0.bounds) } + updateThreadgroupSizes(withFieldSize: size) } } @@ -82,9 +83,24 @@ public class Field { // MARK: - Metal Configuration + private var device: MTLDevice? + private var sampleComputeState: MTLComputePipelineState? + private var parametersBuffer: MTLBuffer? private var ballBuffer: MTLBuffer? private(set) var sampleTexture: MTLTexture? + private var threadgroupCount = MTLSize() + // TODO: It might be possible to (more dynamically) right-size this. + private var threadgroupSize = MTLSize(width: 16, height: 16, depth: 1) + + /// Lazily create (once) the Metal buffer holding the three integer simulation parameters, and return it. + private func makeParametersBufferIfNeeded(withDevice device: MTLDevice) -> MTLBuffer? { + if parametersBuffer == nil { + parametersBuffer = device.makeBuffer(length: MemoryLayout<Int>.size * 3, options: []) + } + return parametersBuffer + } + /// Create a Metal buffer containing the current set of metaballs. /// @param device The Metal device to use to create the buffer. /// @return A new buffer containing metaball data. @@ -116,30 +132,84 @@ public class Field { return sampleTexture } - public func computePipelineStateForSamplingKernel(withDevice device: MTLDevice) throws -> MTLComputePipelineState? { - let library = device.newDefaultLibrary() - if let samplingKernel = library?.makeFunction(name: "samplingKernel") { - let computePipelineState = try device.makeComputePipelineState(function: samplingKernel) - return computePipelineState + /// Update the threadgroup divisions based on the size of the field. + /// @param size The size of the field. 
+ private func updateThreadgroupSizes(withFieldSize size: CGSize) { + let width = Int(size.width) + let height = Int(size.height) + threadgroupCount = MTLSize(width: (width + threadgroupSize.width - 1) / threadgroupSize.width, height: (height + threadgroupSize.height - 1) / threadgroupSize.height, depth: 1) + } + + /// Copy the field size and ball count into the parameters buffer as 32-bit ints, matching the Metal `Parameters` struct (`int2 size; int numberOfBalls`). + private func updateParametersBuffer() { + guard let parameters = parametersBuffer else { + return } - else { - return nil + + var ptr = parameters.contents() + let sizeOfInt = MemoryLayout<Int32>.size + + var width = Int32(size.width) + ptr.copyBytes(from: &width, count: sizeOfInt) + ptr = ptr.advanced(by: sizeOfInt) + + var height = Int32(size.height) + ptr.copyBytes(from: &height, count: sizeOfInt) + ptr = ptr.advanced(by: sizeOfInt) + + var numberOfBalls = Int32(balls.count) + ptr.copyBytes(from: &numberOfBalls, count: sizeOfInt) + } + + public func setupMetal(withDevice device: MTLDevice) throws { + guard self.device == nil else { + return + } + self.device = device + do { + sampleComputeState = try computePipelineStateForSamplingKernel(withDevice: device) + } + catch let e { + throw e } } - public func computeEncoderForSamplingKernel(withDevice device: MTLDevice, commandBuffer buffer: MTLCommandBuffer, state: MTLComputePipelineState) throws -> MTLComputeCommandEncoder { - guard let ballBuffer = makeBallBufferIfNeeded(withDevice: device), - let sampleTexture = makeSampleTextureIfNeeded(withDevice: device) else { + public func computePipelineStateForSamplingKernel(withDevice device: MTLDevice) throws -> MTLComputePipelineState? 
{ + do { + guard let samplingKernelLibraryPath = Bundle.main.path(forResource: "SamplingKernel", ofType: "metal") else { + return nil + } + let library = try device.makeLibrary(filepath: samplingKernelLibraryPath) + guard let samplingKernel = library.makeFunction(name: "sampleFieldKernel") else { + return nil + } + let state = try device.makeComputePipelineState(function: samplingKernel) + return state + } + catch let e { + throw e + } + } + + public func computeEncoderForSamplingKernel(withDevice device: MTLDevice, commandBuffer buffer: MTLCommandBuffer) throws -> MTLComputeCommandEncoder { + guard let parametersBuffer = makeParametersBufferIfNeeded(withDevice: device), + let ballBuffer = makeBallBufferIfNeeded(withDevice: device), + let sampleTexture = makeSampleTextureIfNeeded(withDevice: device), + let state = sampleComputeState + else { throw MetaballsError.metalError } + let encoder = buffer.makeComputeCommandEncoder() encoder.setComputePipelineState(state) - encoder.setBuffer(ballBuffer, offset: 0, at: 0) + encoder.setBuffer(parametersBuffer, offset: 0, at: 0) + encoder.setBuffer(ballBuffer, offset: 0, at: 1) encoder.setTexture(sampleTexture, at: 0) - // TODO: Decide on actual values for these - let threadgroupsPerGrid = MTLSize() - let threadsPerThreadgroup = MTLSize() - encoder.dispatchThreadgroups(threadgroupsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup) + encoder.dispatchThreadgroups(threadgroupCount, threadsPerThreadgroup: threadgroupSize) + encoder.endEncoding() + + updateParametersBuffer() + return encoder } } diff --git a/MetaballsKit/SamplingKernel.metal b/MetaballsKit/SamplingKernel.metal index a8b09ca..753d542 100644 --- a/MetaballsKit/SamplingKernel.metal +++ b/MetaballsKit/SamplingKernel.metal @@ -9,6 +9,12 @@ #include <metal_stdlib> using namespace metal; +typedef struct { + int2 size; + int numberOfBalls; +} Parameters; + +// TODO: This is a dupe of the Ball struct. Is there a way to DRY this? 
typedef struct { float radius; float2 position; @@ -16,14 +22,15 @@ typedef struct { } Ball; kernel void -sampleFieldKernel(constant Ball* balls [[buffer(0)]], - texture2d samples [[texture(1)]], +sampleFieldKernel(constant Parameters& parameters [[buffer(0)]], + constant Ball* balls [[buffer(1)]], + texture2d samples [[texture(0)]], uint2 gid [[thread_position_in_grid]]) { float sample = 0.0; // TODO: Get number of metaballs. for (int i = 0; i < 2; i++) { - constant Ball& ball = metaballs[i]; + constant Ball& ball = balls[i]; float r2 = ball.radius * ball.radius; float xDiff = gid[0] - ball.position[0]; float yDiff = gid[1] - ball.position[1];