diff --git a/ep0099/ep0099.go b/ep0099/ep0099.go index a58fb15..e7fe7aa 100644 --- a/ep0099/ep0099.go +++ b/ep0099/ep0099.go @@ -22,7 +22,7 @@ const ( type Dev struct { i2c i2c.Dev - state map[uint8]State + state [4]State } func New(bus i2c.Bus, address uint16) (*Dev, error) { @@ -32,7 +32,7 @@ func New(bus i2c.Bus, address uint16) (*Dev, error) { d := &Dev{ i2c: i2c.Dev{Bus: bus, Addr: address}, - state: buildChannelsList(), + state: [4]State{StateOff, StateOff, StateOff, StateOff}, } if err := d.reset(); err != nil { @@ -47,30 +47,30 @@ func (d *Dev) Halt() error { } func (d *Dev) On(channel uint8) error { - if err := d.isValidChannel(channel); err != nil { - return err + if !d.isValidChannel(channel) { + return errInvalidChannel } _, err := d.i2c.Write([]byte{channel, byte(StateOn)}) - d.state[channel] = StateOn + d.state[channel-1] = StateOn return err } func (d *Dev) Off(channel uint8) error { - if err := d.isValidChannel(channel); err != nil { - return err + if !d.isValidChannel(channel) { + return errInvalidChannel } _, err := d.i2c.Write([]byte{channel, byte(StateOff)}) - d.state[channel] = StateOff + d.state[channel-1] = StateOff return err } func (d *Dev) State(channel uint8) (State, error) { - if err := d.isValidChannel(channel); err != nil { - return 0, err + if !d.isValidChannel(channel) { + return 0, errInvalidChannel } - return d.state[channel], nil + return d.state[channel-1], nil } func (d *Dev) AvailableChannels() []uint8 { @@ -84,9 +84,8 @@ func (s State) String() string { return "on" } -// Reset resets the registers to the default values. func (d *Dev) reset() error { - for channel := range d.state { + for _, channel := range d.AvailableChannels() { err := d.Off(channel) if err != nil { return err @@ -99,34 +98,14 @@ func (d *Dev) reset() error { // Up to 4 HATs can be stacked and each one need a different address to // work. func isValidAddress(address uint16) error { - validAddresses := [...]uint16{0x10, 0x11, 0x12, 0x13} - - for _, addr := range validAddresses { - if address == addr { - return nil - } + switch address { + case 0x10, 0x11, 0x12, 0x13: + return nil + default: + return errInvalidAddress } - - return errInvalidAddress } -func (d *Dev) isValidChannel(channel uint8) error { - if _, exists := d.state[channel]; !exists { - return errInvalidChannel - } - return nil -} - -// EP-0099 offers 4 channels per board -func buildChannelsList() map[uint8]State { - // Using a map instead of list since indexes of channels are not zero-based - // values. That would cause loops to have to correct channel ids while - // looping through items or reading/setting values. - // With maps, keys correspond to actual channels on the board. - return map[uint8]State{ - 1: StateOff, - 2: StateOff, - 3: StateOff, - 4: StateOff, - } +func (d *Dev) isValidChannel(channel uint8) bool { + return channel >= 1 && channel <= 4 } diff --git a/ep0099/ep0099_test.go b/ep0099/ep0099_test.go index d9a999d..5f14442 100644 --- a/ep0099/ep0099_test.go +++ b/ep0099/ep0099_test.go @@ -8,6 +8,7 @@ import ( "bytes" "errors" "fmt" + "reflect" "testing" "periph.io/x/conn/v3/i2c/i2ctest" @@ -46,16 +47,10 @@ func TestAvailableChannels(t *testing.T) { expected := []uint8{0x01, 0x02, 0x03, 0x04} dev, _ := New(bus, testDefaultValidAddress) - channels := dev.AvailableChannels() + list := dev.AvailableChannels() - if len(channels) != len(expected) { - t.Fatal("Available channels len should be ", len(expected), ", got ", len(channels)) - } - - for i := 0; i < len(expected); i++ { - if channels[i] != expected[i] { - t.Fatal("Available channels should be ", expected, " got ", channels) - } + if !reflect.DeepEqual(expected, list) { + t.Fatal("Available channels should be ", expected, " got ", list) } } @@ -63,8 +58,6 @@ func TestHalt(t *testing.T) { bus := initTestBus() dev, _ := New(bus, testDefaultValidAddress) - resetTestBusOps(bus) - dev.Halt() checkDevReset(t, dev, bus) } @@ -73,8 +66,6 @@ func TestOn(t *testing.T) { bus := initTestBus() dev, _ := New(bus, testDefaultValidAddress) - resetTestBusOps(bus) - err := dev.On(3) if err != nil { @@ -89,8 +80,6 @@ func TestOff(t *testing.T) { bus := initTestBus() dev, _ := New(bus, testDefaultValidAddress) - resetTestBusOps(bus) - err := dev.Off(4) if err != nil { @@ -112,6 +101,20 @@ func TestReturnErrorForInvalidChannel(t *testing.T) { if err := dev.Off(98); err != errInvalidChannel { t.Fatal("Off should return invalid channel error, got ", err) } + + if err := dev.Off(98); err != errInvalidChannel { + t.Fatal("Off should return invalid channel error, got ", err) + } +} + +func TestStateToString(t *testing.T) { + if s := fmt.Sprintf("%s", StateOn); s != "on" { + t.Fatal("StateOn as string should be 'on', got ", s) + } + + if s := fmt.Sprintf("%s", StateOff); s != "off" { + t.Fatal("StateOn as string should be 'off', got ", s) + } } func initTestBus() *i2ctest.Record { @@ -121,10 +124,6 @@ func initTestBus() *i2ctest.Record { } } -func resetTestBusOps(bus *i2ctest.Record) { - bus.Ops = []i2ctest.IO{} -} - func checkChannelState(t *testing.T, dev *Dev, channel uint8, state State) { if actual, _ := dev.State(channel); actual != state { msg := fmt.Sprintf("Channel %d should have state %s, got: %s", channel, state, actual)