diff --git a/MetaballsKit/Metaballs.swift b/MetaballsKit/Metaballs.swift index 8355307..d9ae4a1 100644 --- a/MetaballsKit/Metaballs.swift +++ b/MetaballsKit/Metaballs.swift @@ -10,7 +10,7 @@ import Foundation import MetalKit public enum MetaballsError: Error { - case couldntAddBall + case metalError } public struct Ball { @@ -75,25 +75,27 @@ public class Field { } } - public func add(ball: Ball) throws { - guard bounds.contains(ball.bounds) else { - throw MetaballsError.couldntAddBall - } + public func add(ball: Ball) { + guard bounds.contains(ball.bounds) else { return } balls.append(ball) } // MARK: - Metal Configuration + private var ballBuffer: MTLBuffer? + private(set) var sampleTexture: MTLTexture? + /// Create a Metal buffer containing the current set of metaballs. /// @param device The Metal device to use to create the buffer. /// @return A new buffer containing metaball data. - public func makeBallBuffer(withDevice device: MTLDevice) -> MTLBuffer? { - let sizeOfBall = MemoryLayout.size - let length = balls.count * sizeOfBall - var ballBuffer: MTLBuffer? = nil - balls.withUnsafeMutableBytes { (buffer: UnsafeMutableRawBufferPointer) in - if let bytes = buffer.baseAddress { - ballBuffer = device.makeBuffer(bytesNoCopy: bytes, length: length, options: [], deallocator: nil) + private func makeBallBufferIfNeeded(withDevice device: MTLDevice) -> MTLBuffer? { + if ballBuffer == nil { + let sizeOfBall = MemoryLayout.size + let length = balls.count * sizeOfBall + balls.withUnsafeMutableBytes { (buffer: UnsafeMutableRawBufferPointer) in + if let bytes = buffer.baseAddress { + ballBuffer = device.makeBuffer(bytesNoCopy: bytes, length: length, options: [], deallocator: nil) + } } } return ballBuffer @@ -102,14 +104,16 @@ public class Field { /// Create a Metal texture to hold sample values created by the sampling compute shader. /// @param device The Metal device to use to create the texture. /// @return A new texture. - public func makeSampleTexture(withDevice device: MTLDevice) -> MTLTexture? { - let desc = MTLTextureDescriptor() - desc.pixelFormat = .r16Float - desc.width = Int(size.width) - desc.height = Int(size.height) - desc.usage = .shaderWrite - let texture = device.makeTexture(descriptor: desc) - return texture + private func makeSampleTextureIfNeeded(withDevice device: MTLDevice) -> MTLTexture? { + if sampleTexture == nil { + let desc = MTLTextureDescriptor() + desc.pixelFormat = .r16Float + desc.width = Int(size.width) + desc.height = Int(size.height) + desc.usage = .shaderWrite + sampleTexture = device.makeTexture(descriptor: desc) + } + return sampleTexture } public func computePipelineStateForSamplingKernel(withDevice device: MTLDevice) throws -> MTLComputePipelineState? { @@ -123,11 +127,15 @@ public class Field { } } - public func computeEncoderForSamplingKernel(withCommandBuffer buffer: MTLCommandBuffer, state: MTLComputePipelineState, balls: MTLBuffer, samples: MTLTexture) -> MTLComputeCommandEncoder { + public func computeEncoderForSamplingKernel(withDevice device: MTLDevice, commandBuffer buffer: MTLCommandBuffer, state: MTLComputePipelineState) throws -> MTLComputeCommandEncoder { + guard let ballBuffer = makeBallBufferIfNeeded(withDevice: device), + let sampleTexture = makeSampleTextureIfNeeded(withDevice: device) else { + throw MetaballsError.metalError + } let encoder = buffer.makeComputeCommandEncoder() encoder.setComputePipelineState(state) - encoder.setBuffer(balls, offset: 0, at: 0) - encoder.setTexture(samples, at: 0) + encoder.setBuffer(ballBuffer, offset: 0, at: 0) + encoder.setTexture(sampleTexture, at: 0) // TODO: Decide on actual values for these let threadgroupsPerGrid = MTLSize() let threadsPerThreadgroup = MTLSize()