[MetaballsKit] Finish off Metal compute stuff for the simulation
This commit is contained in:
parent
467f9d4123
commit
da5c664ee6
2 changed files with 95 additions and 18 deletions
|
|
@ -38,6 +38,7 @@ public class Field {
|
||||||
didSet {
|
didSet {
|
||||||
// Remove balls that fall outside the new bounds.
|
// Remove balls that fall outside the new bounds.
|
||||||
balls = balls.filter { bounds.contains($0.bounds) }
|
balls = balls.filter { bounds.contains($0.bounds) }
|
||||||
|
updateThreadgroupSizes(withFieldSize: size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -82,9 +83,24 @@ public class Field {
|
||||||
|
|
||||||
// MARK: - Metal Configuration
|
// MARK: - Metal Configuration
|
||||||
|
|
||||||
|
private var device: MTLDevice?
|
||||||
|
private var sampleComputeState: MTLComputePipelineState?
|
||||||
|
private var parametersBuffer: MTLBuffer?
|
||||||
private var ballBuffer: MTLBuffer?
|
private var ballBuffer: MTLBuffer?
|
||||||
private(set) var sampleTexture: MTLTexture?
|
private(set) var sampleTexture: MTLTexture?
|
||||||
|
|
||||||
|
private var threadgroupCount = MTLSize()
|
||||||
|
// TODO: It might be possible to (more dynamically) right-size this.
|
||||||
|
private var threadgroupSize = MTLSize(width: 16, height: 16, depth: 1)
|
||||||
|
|
||||||
|
/// Create the Metal buffer containing basic parameters of the simulation.
|
||||||
|
private func makeParametersBufferIfNeeded(withDevice device: MTLDevice) -> MTLBuffer? {
|
||||||
|
if parametersBuffer == nil {
|
||||||
|
parametersBuffer = device.makeBuffer(length: MemoryLayout<Int>.size * 3, options: [])
|
||||||
|
}
|
||||||
|
return parametersBuffer
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a Metal buffer containing the current set of metaballs.
|
/// Create a Metal buffer containing the current set of metaballs.
|
||||||
/// @param device The Metal device to use to create the buffer.
|
/// @param device The Metal device to use to create the buffer.
|
||||||
/// @return A new buffer containing metaball data.
|
/// @return A new buffer containing metaball data.
|
||||||
|
|
@ -116,30 +132,84 @@ public class Field {
|
||||||
return sampleTexture
|
return sampleTexture
|
||||||
}
|
}
|
||||||
|
|
||||||
public func computePipelineStateForSamplingKernel(withDevice device: MTLDevice) throws -> MTLComputePipelineState? {
|
/// Update the threadgroup divisions based on the size of the field.
|
||||||
let library = device.newDefaultLibrary()
|
/// @param size The size of the field.
|
||||||
if let samplingKernel = library?.makeFunction(name: "samplingKernel") {
|
private func updateThreadgroupSizes(withFieldSize size: CGSize) {
|
||||||
let computePipelineState = try device.makeComputePipelineState(function: samplingKernel)
|
let width = Int(size.width)
|
||||||
return computePipelineState
|
let height = Int(size.height)
|
||||||
|
threadgroupCount = MTLSize(width: width + threadgroupSize.width - 1, height: height + threadgroupSize.height - 1, depth: 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copy metaballs data into the parameters buffer.
|
||||||
|
private func updateParametersBuffer() {
|
||||||
|
guard let parameters = parametersBuffer else {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
else {
|
|
||||||
return nil
|
var ptr = parameters.contents()
|
||||||
|
let sizeOfInt = MemoryLayout<Int>.size
|
||||||
|
|
||||||
|
var width = Int(size.width)
|
||||||
|
ptr.copyBytes(from: &width, count: sizeOfInt)
|
||||||
|
ptr = ptr.advanced(by: sizeOfInt)
|
||||||
|
|
||||||
|
var height = Int(size.height)
|
||||||
|
ptr.copyBytes(from: &height, count: sizeOfInt)
|
||||||
|
ptr = ptr.advanced(by: sizeOfInt)
|
||||||
|
|
||||||
|
var numberOfBalls = balls.count
|
||||||
|
ptr.copyBytes(from: &numberOfBalls, count: sizeOfInt)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func setupMetal(withDevice device: MTLDevice) throws {
|
||||||
|
guard self.device == nil else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
self.device = device
|
||||||
|
do {
|
||||||
|
sampleComputeState = try computePipelineStateForSamplingKernel(withDevice: device)
|
||||||
|
}
|
||||||
|
catch let e {
|
||||||
|
throw e
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public func computeEncoderForSamplingKernel(withDevice device: MTLDevice, commandBuffer buffer: MTLCommandBuffer, state: MTLComputePipelineState) throws -> MTLComputeCommandEncoder {
|
public func computePipelineStateForSamplingKernel(withDevice device: MTLDevice) throws -> MTLComputePipelineState? {
|
||||||
guard let ballBuffer = makeBallBufferIfNeeded(withDevice: device),
|
do {
|
||||||
let sampleTexture = makeSampleTextureIfNeeded(withDevice: device) else {
|
guard let samplingKernelLibraryPath = Bundle.main.path(forResource: "SamplingKernel", ofType: "metal") else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
let library = try device.makeLibrary(filepath: samplingKernelLibraryPath)
|
||||||
|
guard let samplingKernel = library.makeFunction(name: "samplingKernel") else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
let state = try device.makeComputePipelineState(function: samplingKernel)
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
catch let e {
|
||||||
|
throw e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public func computeEncoderForSamplingKernel(withDevice device: MTLDevice, commandBuffer buffer: MTLCommandBuffer) throws -> MTLComputeCommandEncoder {
|
||||||
|
guard let parametersBuffer = makeParametersBufferIfNeeded(withDevice: device),
|
||||||
|
let ballBuffer = makeBallBufferIfNeeded(withDevice: device),
|
||||||
|
let sampleTexture = makeSampleTextureIfNeeded(withDevice: device),
|
||||||
|
let state = sampleComputeState
|
||||||
|
else {
|
||||||
throw MetaballsError.metalError
|
throw MetaballsError.metalError
|
||||||
}
|
}
|
||||||
|
|
||||||
let encoder = buffer.makeComputeCommandEncoder()
|
let encoder = buffer.makeComputeCommandEncoder()
|
||||||
encoder.setComputePipelineState(state)
|
encoder.setComputePipelineState(state)
|
||||||
encoder.setBuffer(ballBuffer, offset: 0, at: 0)
|
encoder.setBuffer(parametersBuffer, offset: 0, at: 0)
|
||||||
|
encoder.setBuffer(ballBuffer, offset: 0, at: 1)
|
||||||
encoder.setTexture(sampleTexture, at: 0)
|
encoder.setTexture(sampleTexture, at: 0)
|
||||||
// TODO: Decide on actual values for these
|
encoder.dispatchThreadgroups(threadgroupCount, threadsPerThreadgroup: threadgroupSize)
|
||||||
let threadgroupsPerGrid = MTLSize()
|
encoder.endEncoding()
|
||||||
let threadsPerThreadgroup = MTLSize()
|
|
||||||
encoder.dispatchThreadgroups(threadgroupsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
|
updateParametersBuffer()
|
||||||
|
|
||||||
return encoder
|
return encoder
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,12 @@
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int2 size;
|
||||||
|
int numberOfBalls;
|
||||||
|
} Parameters;
|
||||||
|
|
||||||
|
// TODO: This is a dupe of the Ball struct. Is there a way to DRY this?
|
||||||
typedef struct {
|
typedef struct {
|
||||||
float radius;
|
float radius;
|
||||||
float2 position;
|
float2 position;
|
||||||
|
|
@ -16,14 +22,15 @@ typedef struct {
|
||||||
} Ball;
|
} Ball;
|
||||||
|
|
||||||
kernel void
|
kernel void
|
||||||
sampleFieldKernel(constant Ball* balls [[buffer(0)]],
|
sampleFieldKernel(constant Parameters& parameters [[buffer(0)]],
|
||||||
texture2d<half, access::write> samples [[texture(1)]],
|
constant Ball* balls [[buffer(1)]],
|
||||||
|
texture2d<half, access::write> samples [[texture(0)]],
|
||||||
uint2 gid [[thread_position_in_grid]])
|
uint2 gid [[thread_position_in_grid]])
|
||||||
{
|
{
|
||||||
float sample = 0.0;
|
float sample = 0.0;
|
||||||
// TODO: Get number of metaballs.
|
// TODO: Get number of metaballs.
|
||||||
for (int i = 0; i < 2; i++) {
|
for (int i = 0; i < 2; i++) {
|
||||||
constant Ball& ball = metaballs[i];
|
constant Ball& ball = balls[i];
|
||||||
float r2 = ball.radius * ball.radius;
|
float r2 = ball.radius * ball.radius;
|
||||||
float xDiff = gid[0] - ball.position[0];
|
float xDiff = gid[0] - ball.position[0];
|
||||||
float yDiff = gid[1] - ball.position[1];
|
float yDiff = gid[1] - ball.position[1];
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue